diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1657,9 +1657,8 @@ ]> { let summary = "tensor splat or broadcast operation"; let description = [{ - Broadcast the operand to all elements of the result tensor. The operand is - required to be of integer/index/float type, and the result tensor must be - statically shaped. + Broadcast the operand to all elements of the result tensor. The result + tensor must be statically shaped. Example: @@ -1680,8 +1679,7 @@ ``` }]; - let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], - "integer/index/float type">:$input); + let arguments = (ins AnyType:$input); let results = (outs AnyStaticShapeTensor:$aggregate); let builders = [ diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -400,14 +400,6 @@ // ----- -func.func @invalid_splat(%v : vector<8xf32>) { - // expected-error@+1 {{must be integer/index/float type}} - %w = tensor.splat %v : tensor<8xvector<8xf32>> - return -} - -// ----- - func.func @gather_empty_dims( %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { // expected-error@+1 {{gather_dims must be non-empty}}