[mlir][arith] Allow 0-D vectors in arith.cmpi

Introduce temporary "OfAnyRank" clones of the bool-like and
signless-integer-like type constraints (BoolLikeOfAnyRank,
SignlessIntegerLikeOfAnyRank) that use VectorOfAnyRankOf and therefore also
admit 0-D vectors, plus an Arith_CompareOpOfAnyRank base class whose result
uses BoolLikeOfAnyRank. arith.cmpi is switched over to these. The TODOs mark
the clones for removal once all ops support 0-D vectors.

New tests: ops.mlir (0-D cmpi operands/result), arith-to-llvm.mlir (a 0-D
cmpi lowers to llvm.icmp on vector<1xi32>), and test-0-d-vectors.mlir
(execution of ult/ugt/eq on 0-D vectors, checked through vector.print).
Also includes whitespace-only cleanups in ArithmeticOps.td, OpBase.td and
test-0-d-vectors.mlir.

diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -116,6 +116,13 @@
   let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
 }
 
+// Just like `Arith_CompareOp` but also admits 0-D vectors. Introduced
+// temporarily to allow gradual transition to 0-D vectors.
+class Arith_CompareOpOfAnyRank<string mnemonic, list<OpTrait> traits = []> :
+    Arith_CompareOp<mnemonic, traits> {
+  let results = (outs BoolLikeOfAnyRank:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -284,7 +291,7 @@
 def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> {
   let summary = "unsigned ceil integer division operation";
   let description = [{
-    Unsigned integer division. Rounds towards positive infinity. Treats the 
+    Unsigned integer division. Rounds towards positive infinity. Treats the
     leading bit as the most significant, i.e. for `i16` given two's complement
     representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
 
@@ -990,7 +997,7 @@
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
+def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
   let summary = "integer comparison operation";
   let description = [{
     The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1057,8 +1064,8 @@
   }];
 
   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       SignlessIntegerLike:$lhs,
-                       SignlessIntegerLike:$rhs);
+                       SignlessIntegerLikeOfAnyRank:$lhs,
+                       SignlessIntegerLikeOfAnyRank:$rhs);
 
   let builders = [
     OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -213,6 +213,7 @@
                             CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
 
 // Whether a type is a TensorType.
@@ -603,7 +604,9 @@
 class VectorOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
                       "::mlir::VectorType">;
+
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
 class VectorOfAnyRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
                       "::mlir::VectorType">;
@@ -835,6 +838,14 @@
                                   TensorOf<[I1]>.predicate]>,
     "bool-like">;
 
+// Temporary constraint to allow gradual transition to supporting 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
+def BoolLikeOfAnyRank : TypeConstraint<Or<[
+        I1.predicate,
+        VectorOfAnyRankOf<[I1]>.predicate,
+        TensorOf<[I1]>.predicate]>,
+    "bool-like">;
+
 // Type constraint for signless-integer-like types: signless integers, indices,
 // vectors of signless integers or indices, tensors of signless integers.
 def SignlessIntegerLike : TypeConstraint<Or<[
@@ -843,6 +854,14 @@
         TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
     "signless-integer-like">;
 
+// Temporary constraint to allow gradual transition to supporting 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
+def SignlessIntegerLikeOfAnyRank : TypeConstraint<Or<[
+        AnySignlessIntegerOrIndex.predicate,
+        VectorOfAnyRankOf<[AnySignlessIntegerOrIndex]>.predicate,
+        TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
+    "signless-integer-like">;
+
 // Type constraint for float-like types: floats, vectors or tensors thereof.
 def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
         VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,
diff --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
--- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
@@ -352,6 +352,17 @@
 
 // -----
 
+// CHECK-LABEL: func @cmpi_0dvector(
+func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
+  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32>
+  %0 = arith.cmpi ult, %arg0, %arg1 : vector<i32>
+  std.return
+}
+
+// -----
+
 // CHECK-LABEL: func @cmpi_2dvector(
 func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
   // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -631,6 +631,12 @@
   return %0 : vector<8xi1>
 }
 
+// CHECK-LABEL: test_cmpi_vector_0d
+func @test_cmpi_vector_0d(%arg0 : vector<i64>, %arg1 : vector<i64>) -> vector<i1> {
+  %0 = arith.cmpi ult, %arg0, %arg1 : vector<i64>
+  return %0 : vector<i1>
+}
+
 // CHECK-LABEL: test_cmpf
 func @test_cmpf(%arg0 : f64, %arg1 : f64) -> i1 {
   %0 = arith.cmpf oeq, %arg0, %arg1 : f64
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -67,7 +67,6 @@
   return
 }
 
-
 func @constant_mask_0d() {
   %1 = vector.constant_mask [0] : vector<i1>
   // CHECK: ( 0 )
@@ -78,6 +77,22 @@
   return
 }
 
+func @arith_cmpi_0d(%smaller : vector<i32>, %bigger : vector<i32>) {
+  %0 = arith.cmpi ult, %smaller, %bigger : vector<i32>
+  // CHECK: ( 1 )
+  vector.print %0: vector<i1>
+
+  %1 = arith.cmpi ugt, %smaller, %bigger : vector<i32>
+  // CHECK: ( 0 )
+  vector.print %1: vector<i1>
+
+  %2 = arith.cmpi eq, %smaller, %bigger : vector<i32>
+  // CHECK: ( 0 )
+  vector.print %2: vector<i1>
+
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -96,5 +111,9 @@
   call @bitcast_0d() : () -> ()
   call @constant_mask_0d() : () -> ()
 
+  %smaller = arith.constant dense<42> : vector<i32>
+  %bigger = arith.constant dense<4242> : vector<i32>
+  call @arith_cmpi_0d(%smaller, %bigger) : (vector<i32>, vector<i32>) -> ()
+
   return
 }