Index: mlir/include/mlir/Dialect/Vector/VectorOps.td
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1301,6 +1301,7 @@
       "ValueRange":$indices,
       CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>
   ];
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
@@ -1395,6 +1396,7 @@
       "AffineMap":$permutationMap,
       "ArrayAttr":$masked)>,
   ];
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2490,6 +2490,118 @@
                          SideEffects::DefaultResource::get());
 }
 
+/// Returns true if 'map' is a suffix of an identity affine map except for
+/// broadcasted dimensions (which are indicated by 0 values in the result). If
+/// `broadcastedDims` is not null, it will be populated with the indices of the
+/// broadcasted dimensions.
+/// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
+/// `broadcastedDims` will contain [0, 2] (the result indices equal to 0)
+static bool isMinorIdentityWithBroadcasting(
+    AffineMap map, SmallVectorImpl<unsigned> *broadcastedDims = nullptr) {
+  if (broadcastedDims)
+    broadcastedDims->clear();
+  if (map.getNumDims() < map.getNumResults())
+    return false;
+  unsigned suffixStart = map.getNumDims() - map.getNumResults();
+  for (auto idxAndExpr : llvm::enumerate(map.getResults())) {
+    unsigned resIdx = idxAndExpr.index();
+    AffineExpr expr = idxAndExpr.value();
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      // Each result must be either a constant 0 (broadcasted dimension)
+      if (constExpr.getValue() != 0)
+        return false;
+      if (broadcastedDims)
+        broadcastedDims->push_back(resIdx);
+    } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+      // or the input dimension corresponding to this result position.
+      if (dimExpr.getPosition() != suffixStart + resIdx)
+        return false;
+    } else {
+      return false;
+    }
+  }
+  return true;
+}
+
+namespace {
+/// Progressive lowering of transfer_read. This pattern supports lowering of
+/// `vector.transfer_read` to a combination of `vector.load` and
+/// `vector.broadcast` if all of the following hold:
+/// - The op reads from a memref with the default layout.
+/// - Masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   result type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+struct TransferReadToVectorLoadLowering final
+    : public OpRewritePattern<vector::TransferReadOp> {
+  TransferReadToVectorLoadLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferReadOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned, 4> broadcastedDims;
+    // TODO: Support permutations.
+    if (!isMinorIdentityWithBroadcasting(read.permutation_map(),
+                                         &broadcastedDims))
+      return failure();
+    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
+
+    // If there is broadcasting involved then we first load the unbroadcasted
+    // vector, and then broadcast it with `vector.broadcast`.
+    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
+    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
+                                                     vectorShape.end());
+    for (unsigned i : broadcastedDims)
+      unbroadcastedVectorShape[i] = 1;
+    VectorType unbroadcastedVectorType = VectorType::get(
+        unbroadcastedVectorShape, read.getVectorType().getElementType());
+
+    // `vector.load` supports vector types as memref's elements only when the
+    // resulting vector type is the same as the element type.
+    if (memRefType.getElementType().isa<VectorType>() &&
+        memRefType.getElementType() != unbroadcastedVectorType)
+      return failure();
+    // Only the default layout is supported by `vector.load`.
+    // TODO: Support non-default layouts.
+    if (!memRefType.getAffineMaps().empty())
+      return failure();
+    // TODO: When masking is required, we can create a MaskedLoadOp
+    if (read.hasMaskedDim())
+      return failure();
+
+    Operation *loadOp;
+    if (!broadcastedDims.empty() &&
+        unbroadcastedVectorType.getNumElements() == 1) {
+      // If broadcasting is required and the number of loaded elements is 1 then
+      // we can create a scalar `std.load` instead of `vector.load`.
+      loadOp = rewriter.create<mlir::LoadOp>(read.getLoc(), read.source(),
+                                             read.indices());
+    } else {
+      // Otherwise create `vector.load`.
+      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
+                                               unbroadcastedVectorType,
+                                               read.source(), read.indices());
+    }
+
+    // Insert a broadcasting op if required.
+    if (!broadcastedDims.empty()) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          read, read.getVectorType(), loadOp->getResult(0));
+    } else {
+      rewriter.replaceOp(read, loadOp->getResult(0));
+    }
+
+    return success();
+  }
+};
+} // namespace
+
+void TransferReadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TransferReadToVectorLoadLowering>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -2612,6 +2724,51 @@
                          SideEffects::DefaultResource::get());
 }
 
+namespace {
+/// Progressive lowering of transfer_write. This pattern supports lowering of
+/// `vector.transfer_write` to `vector.store` if all of the following hold:
+/// - The op writes to a memref with the default layout.
+/// - Masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   type of the written value.
+/// - The permutation map is the minor identity map (neither permutation nor
+///   broadcasting is allowed).
+struct TransferWriteToVectorStoreLowering final
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteToVectorStoreLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Support non-minor-identity maps
+    if (!write.permutation_map().isMinorIdentity())
+      return failure();
+    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
+    // `vector.store` supports vector types as memref's elements only when the
+    // type of the vector value being written is the same as the element type.
+    if (memRefType.getElementType().isa<VectorType>() &&
+        memRefType.getElementType() != write.getVectorType())
+      return failure();
+    // Only the default layout is supported by `vector.store`.
+    // TODO: Support non-default layouts.
+    if (!memRefType.getAffineMaps().empty())
+      return failure();
+    // TODO: When masking is required, we can create a MaskedStoreOp
+    if (write.hasMaskedDim())
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        write, write.vector(), write.source(), write.indices());
+    return success();
+  }
+};
+} // namespace
+
+void TransferWriteOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TransferWriteToVectorStoreLowering>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
Index: mlir/test/Dialect/Vector/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Vector/canonicalize.mlir
+++ mlir/test/Dialect/Vector/canonicalize.mlir
@@ -252,15 +252,15 @@
 // -----
 
 // CHECK-LABEL: cast_transfers
-func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
+func @cast_transfers(%A: memref<2x8xf32>) -> (vector<4x8xf32>) {
   %c0 = constant 0 : index
   %f0 = constant 0.0 : f32
-  %0 = memref_cast %A : memref<4x8xf32> to memref<?x?xf32>
+  %0 = memref_cast %A : memref<2x8xf32> to memref<?x?xf32>
 
-  // CHECK: vector.transfer_read %{{.*}} {masked = [false, false]} : memref<4x8xf32>, vector<4x8xf32>
+  // CHECK: vector.transfer_read %{{.*}} {masked = [true, false]} : memref<2x8xf32>, vector<4x8xf32>
   %1 = vector.transfer_read %0[%c0, %c0], %f0 : memref<?x?xf32>, vector<4x8xf32>
 
-  // CHECK: vector.transfer_write %{{.*}} {masked = [false, false]} : vector<4x8xf32>, memref<4x8xf32>
+  // CHECK: vector.transfer_write %{{.*}} {masked = [true, false]} : vector<4x8xf32>, memref<2x8xf32>
   vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
   return %1 : vector<4x8xf32>
 }
@@ -770,3 +770,212 @@
   return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8>
 }
 
+// -----
+
+// transfer_read/write are lowered to vector.load/store
+// CHECK-LABEL: func @transfer_to_load(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// n-D results are also supported.
+// CHECK-LABEL: func @transfer_2D(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_2D(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false]} : memref<8x8xf32>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false, false]} : vector<2x4xf32>, memref<8x8xf32>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// Vector element types are supported when the result has the same type.
+// CHECK-LABEL: func @transfer_vector_element(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant dense<0.0> : vector<2x4xf32>
+  %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] : vector<2x4xf32>, memref<8x8xvector<2x4xf32>>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// TODO: Vector element types are not supported yet when the result has a
+// different type.
+// CHECK-LABEL: func @transfer_vector_element_different_types(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1x2x4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant dense<0.000000e+00> : vector<2x4xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
+// CHECK-NEXT: return %[[RES]] : vector<1x2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_vector_element_different_types(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<1x2x4xf32> {
+  %cf0 = constant dense<0.0> : vector<2x4xf32>
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
+  return %res : vector<1x2x4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered because there is a masked
+// dimension.
+// CHECK-LABEL: func @transfer_2D_masked(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_2D_masked(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered because they are masked (they carry no `masked` attribute).
+// CHECK-LABEL: func @transfer_masked(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_masked(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered to vector.load/store because the
+// memref has a non-default (non-identity) layout map.
+// CHECK-LABEL: func @transfer_nondefault_layout(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xf32, #{{.*}}>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #{{.*}}>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+#layout = affine_map<(d0, d1) -> (d0*16 + d1)>
+func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32, #layout>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #layout>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered to vector.load/store yet when the
+// permutation map is not the minor identity map (up to broadcasting), e.g. (d0, d1) -> (d0).
+// CHECK-LABEL: func @transfer_perm_map(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// Lowering of transfer_read with broadcasting is supported (note that a scalar
+// `load` followed by `vector.broadcast` is generated instead of a `vector.load`).
+// CHECK-LABEL: func @transfer_broadcasting(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+#broadcast = affine_map<(d0, d1) -> (0)>
+func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// An example with two broadcasted dimensions (a single scalar `load` still suffices).
+// CHECK-LABEL: func @transfer_broadcasting_2D(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
+// CHECK-NEXT: }
+
+#broadcast = affine_map<(d0, d1) -> (0, 0)>
+func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vector<4x4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4x4xf32>
+  return %res : vector<4x4xf32>
+}
+
+// -----
+
+// More complex broadcasting case (here a `vector.load` of the unbroadcasted vector, with broadcast dims set to 1, is generated).
+// CHECK-LABEL: func @transfer_broadcasting_complex(
+// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32>
+// CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32>
+// CHECK-NEXT: }
+
+#broadcast = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
+func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
+  return %res : vector<3x2x4x5xf32>
+}
+