diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -130,13 +130,10 @@
 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
 static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
-  if (auto tensorType = llvm::dyn_cast<RankedTensorType>(type))
-    return RankedTensorType::get(tensorType.getShape(), i1Type);
+  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
+    return shapedType.cloneWith(std::nullopt, i1Type);
   if (llvm::isa<UnrankedTensorType>(type))
     return UnrankedTensorType::get(i1Type);
-  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
-    return VectorType::get(vectorType.getShape(), i1Type,
-                           vectorType.getScalableDims());
   return i1Type;
 }
 
@@ -1150,9 +1147,21 @@
                            type_list<ElementTypes...>());
 }
 
+/// Return false if both types are ranked tensor with mismatching encoding.
+static bool hasSameEncoding(Type typeA, Type typeB) {
+  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
+  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
+  if (!rankedTensorA || !rankedTensorB)
+    return true;
+  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
+}
+
 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
-  return inputs.size() == 1 && outputs.size() == 1 &&
-         succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  if (!hasSameEncoding(inputs.front(), outputs.front()))
+    return false;
+  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -206,6 +206,15 @@
 
 // -----
 
+func.func @func_with_ops() {
+^bb0:
+  %c = arith.constant dense<0> : tensor<42 x i32, "foo">
+  // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
+  %r = "arith.cmpi"(%c, %c) {predicate = 0} : (tensor<42 x i32, "foo">, tensor<42 x i32, "foo">) -> tensor<42 x i1, "bar">
+}
+
+// -----
+
 func.func @invalid_cmp_shape(%idx : () -> ()) {
   // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
   %cmp = arith.cmpi eq, %idx, %idx : () -> ()
@@ -420,6 +429,14 @@
 
 // -----
 
+func.func @fpext_vec_f32_to_i32(%arg0 : tensor<2xf32, "foo">) {
+  // expected-error@+1 {{op operand type 'tensor<2xf32, "foo">' and result type 'tensor<2xf64, "bar">' are cast incompatible}}
+  %0 = arith.extf %arg0 : tensor<2xf32, "foo"> to tensor<2xf64, "bar">
+  return
+}
+
+// -----
+
 func.func @fptrunc_f16_to_f32(%arg0 : f16) {
   // expected-error@+1 {{are cast incompatible}}
   %0 = arith.truncf %arg0 : f16 to f32
@@ -769,3 +786,12 @@
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
   return %0 : tensor<2x?xi64>
 }
+
+// -----
+
+func.func @select_tensor_encoding(
+  %arg0 : tensor<8xi1, "bar">, %arg1 : tensor<8xi32, "foo">, %arg2 : tensor<8xi32, "foo">) -> tensor<8xi32, "foo"> {
+  // expected-error @+1 {{'arith.select' op expected condition type to have the same shape as the result type}}
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
+  return %0 : tensor<8xi32, "foo">
+}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -637,6 +637,12 @@
   return %0 : tensor<8x8xf64>
 }
 
+// CHECK-LABEL: test_extf_tensor_encoding
+func.func @test_extf_tensor_encoding(%arg0 : tensor<8x8xf32, "foo">) -> tensor<8x8xf64, "foo"> {
+  %0 = arith.extf %arg0 : tensor<8x8xf32, "foo"> to tensor<8x8xf64, "foo">
+  return %0 : tensor<8x8xf64, "foo">
+}
+
 // CHECK-LABEL: test_extf_vector
 func.func @test_extf_vector(%arg0 : vector<8xf32>) -> vector<8xf64> {
   %0 = arith.extf %arg0 : vector<8xf32> to vector<8xf64>
@@ -950,6 +956,12 @@
   return %0 : tensor<8x8xi1>
 }
 
+// CHECK-LABEL: test_cmpi_tensor_encoding
+func.func @test_cmpi_tensor_encoding(%arg0 : tensor<8x8xi64, "foo">, %arg1 : tensor<8x8xi64, "foo">) -> tensor<8x8xi1, "foo"> {
+  %0 = arith.cmpi slt, %arg0, %arg1 : tensor<8x8xi64, "foo">
+  return %0 : tensor<8x8xi1, "foo">
+}
+
 // CHECK-LABEL: test_cmpi_vector
 func.func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi1> {
   %0 = arith.cmpi ult, %arg0, %arg1 : vector<8xi64>
@@ -1103,3 +1115,18 @@
 
   return
 }
+
+// CHECK-LABEL: @select_tensor
+func.func @select_tensor(%arg0 : tensor<8xi1>, %arg1 : tensor<8xi32>, %arg2 : tensor<8xi32>) -> tensor<8xi32> {
+  // CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1>, tensor<8xi32>
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1>, tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+
+// CHECK-LABEL: @select_tensor_encoding
+func.func @select_tensor_encoding(
+  %arg0 : tensor<8xi1, "foo">, %arg1 : tensor<8xi32, "foo">, %arg2 : tensor<8xi32, "foo">) -> tensor<8xi32, "foo"> {
+  // CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1, "foo">, tensor<8xi32, "foo">
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
+  return %0 : tensor<8xi32, "foo">
+}