diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -114,6 +114,21 @@
       /*methodBody=*/"return $_op.permutation_map();"
       /*defaultImplementation=*/
     >,
+    InterfaceMethod<
+      /*desc=*/[{ Returns true if at least one of the permutation map results
+                  is a broadcast, i.e. the constant affine expression 0.}],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasBroadcastDim",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::any_of(
+          $_op.permutation_map().getResults(),
+          [](AffineExpr e) {
+            return e.isa<AffineConstantExpr>() &&
+                   e.cast<AffineConstantExpr>().getValue() == 0; });
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/"Return the `in_bounds` boolean ArrayAttr.",
       /*retTy=*/"Optional<ArrayAttr>",
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2683,6 +2683,11 @@
   if (llvm::size(op.indices()) != shapedType.getRank())
     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
 
+  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
+  // as their semantics are unclear. This can be revisited later if necessary.
+  if (op.hasBroadcastDim())
+    return op.emitOpError("should not have broadcast dimensions");
+
   if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
                               permutationMap,
                               op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -461,6 +461,17 @@
 
 // -----
 
+func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>) {
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{should not have broadcast dimensions}}
+  vector.transfer_write %arg1, %arg0[%c3]
+      {permutation_map = affine_map<(d0) -> (0)>}
+      : vector<7xf32>, memref<?xf32>
+}
+
+// -----
+
 func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected offsets of same size as destination vector rank}}
   %1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>