diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -341,6 +341,7 @@
 LogicalResult verifySameOperandsElementType(Operation *op);
 LogicalResult verifySameOperandsAndResultElementType(Operation *op);
 LogicalResult verifySameOperandsAndResultType(Operation *op);
+LogicalResult verifySameOperandsAndResultRank(Operation *op);
 LogicalResult verifyResultsAreBoolLike(Operation *op);
 LogicalResult verifyResultsAreFloatLike(Operation *op);
 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
@@ -1117,6 +1118,18 @@
   }
 };
 
+/// This class verifies that the op has at least one operand and that all of
+/// its operand and result types that are ranked shaped types have the same
+/// rank. Unranked shaped types and non-shaped types are not checked.
+template <typename ConcreteType>
+class AllRanksMatchIfKnown
+    : public TraitBase<ConcreteType, AllRanksMatchIfKnown> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsAndResultRank(op);
+  }
+};
+
 /// This class verifies that any results of the specified op have a boolean
 /// type, a vector thereof, or a tensor thereof.
 template <typename ConcreteType>
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -369,4 +369,7 @@
 // TODO: Change from hard coded to utilizing type inference trait.
 def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
 
+// Op has the same rank for all ranked operand and result types.
+def AllRanksMatchIfKnown : NativeOpTrait<"AllRanksMatchIfKnown">;
+
 #endif // MLIR_INFERTYPEOPINTERFACE
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1084,6 +1084,53 @@
   return success();
 }
 
+LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) {
+  if (failed(verifyAtLeastNOperands(op, 1)))
+    return failure();
+
+  // Returns true if `type` is a shaped type with a known rank. Unranked shaped
+  // types and non-shaped types do not participate in the rank check.
+  auto hasRank = [](Type type) {
+    if (auto shapedType = dyn_cast<ShapedType>(type))
+      return shapedType.hasRank();
+    return false;
+  };
+
+  auto rankedOperandTypes =
+      llvm::make_filter_range(op->getOperandTypes(), hasRank);
+  auto rankedResultTypes =
+      llvm::make_filter_range(op->getResultTypes(), hasRank);
+
+  // If all operands and results are unranked, then no further verification.
+  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+    return success();
+
+  // Returns the rank of a type already known to be a ranked shaped type.
+  auto getRank = [](Type type) { return cast<ShapedType>(type).getRank(); };
+
+  // The first ranked operand provides the reference rank; the first ranked
+  // result is used only when no operand is ranked.
+  const bool hasRankedOperand = !rankedOperandTypes.empty();
+  const int64_t rank = hasRankedOperand
+                           ? getRank(*rankedOperandTypes.begin())
+                           : getRank(*rankedResultTypes.begin());
+
+  for (Type type : rankedOperandTypes)
+    if (getRank(type) != rank)
+      return op->emitOpError("operands don't have matching ranks");
+
+  // Without a ranked operand the reference rank comes from the first ranked
+  // result, so a mismatch is between results rather than against operands.
+  StringRef resultMismatchMsg =
+      hasRankedOperand ? "result type has different rank than operands"
+                       : "results don't have matching ranks";
+  for (Type type : rankedResultTypes)
+    if (getRank(type) != rank)
+      return op->emitOpError(resultMismatchMsg);
+
+  return success();
+}
+
 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
   Block *block = op->getBlock();
   // Verify that the operation is at the end of the respective parent block.
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -694,6 +694,13 @@
   let results = (outs AnyShaped:$res);
 }
 
+def OperandsAndResultsHaveSameRankIfKnown :
+    TEST_Op<"operands_and_result_have_same_rank_if_known",
+            [AllRanksMatchIfKnown]> {
+  let arguments = (ins AnyShaped:$x, AnyShaped:$y);
+  let results = (outs AnyShaped:$res);
+}
+
 def OperandZeroAndResultHaveSameShape :
     TEST_Op<"operand0_and_result_have_same_shape",
             [AllShapesMatch<["x", "res"]>]> {
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -377,6 +377,35 @@
 
 // -----
 
+// CHECK-LABEL: same_rank_if_known_success
+func.func @same_rank_if_known_success(%t1xi : tensor<1xi32>, %t2xf : tensor<2xf32>, %m3xi : memref<3xi32>, %t1x2xf : tensor<1x2xf32>, %tuxi : tensor<*xi32>) {
+  %0 = "test.operands_and_result_have_same_rank_if_known"(%t1xi, %t2xf) : (tensor<1xi32>, tensor<2xf32>) -> (tensor<3xf64>)
+  %1 = "test.operands_and_result_have_same_rank_if_known"(%t1xi, %m3xi) : (tensor<1xi32>, memref<3xi32>) -> (tensor<3xi64>)
+  %2 = "test.operands_and_result_have_same_rank_if_known"(%tuxi, %t2xf) : (tensor<*xi32>, tensor<2xf32>) -> (tensor<2xf32>)
+  %3 = "test.operands_and_result_have_same_rank_if_known"(%t1x2xf, %tuxi) : (tensor<1x2xf32>, tensor<*xi32>) -> (tensor<1x2xf64>)
+  %4 = "test.operands_and_result_have_same_rank_if_known"(%tuxi, %tuxi) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
+  %5 = "test.operands_and_result_have_same_rank_if_known"(%tuxi, %tuxi) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<1x2x3xi32>)
+  return
+}
+
+// -----
+
+func.func @same_rank_if_known_operand_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
+  // expected-error@+1 {{operands don't have matching ranks}}
+  %0 = "test.operands_and_result_have_same_rank_if_known"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+func.func @same_rank_if_known_result_failure(%arg0: tensor<1x2xf32>) {
+  // expected-error@+1 {{result type has different rank than operands}}
+  %0 = "test.operands_and_result_have_same_rank_if_known"(%arg0, %arg0) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2x3xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: same_shape_success
 func.func @same_shape_success(%t2x3: tensor<2x3xi32>, %m2x3: memref<2x3xf32>, %v2x3 : vector<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
   "test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<2x3xf32>)