Allow std.select to take a scalar i1 condition with vector/tensor operands.

A signless i1 condition now selects the entire true/false value, while an i1
vector/tensor condition with the same shape as the result still selects
element-wise. The TableGen shape/type constraints and the declarative assembly
format of SelectOp are replaced by a custom printer, parser and verifier, and
the CondBranchOp same-successor canonicalization no longer refuses to insert
selects for vector/tensor block operands. The range operator== in STLExtras
casts the std::distance result before comparing it with size().

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1124,7 +1124,8 @@
 
   /// Compare this range with another.
   template <typename OtherT> bool operator==(const OtherT &other) const {
-    return size() == std::distance(other.begin(), other.end()) &&
+    return size() ==
+               static_cast<size_t>(std::distance(other.begin(), other.end())) &&
            std::equal(begin(), end(), other.begin());
   }
   template <typename OtherT> bool operator!=(const OtherT &other) const {
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1915,11 +1915,8 @@
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
-     AllTypesMatch<["true_value", "false_value", "result"]>,
-     TypesMatchWith<"condition type matches i1 equivalent of result type",
-                     "result", "condition",
-                     "getI1SameShape($_self)">]> {
+def SelectOp : Std_Op<"select", [NoSideEffect,
+     AllTypesMatch<["true_value", "false_value", "result"]>]> {
   let summary = "select operation";
   let description = [{
     The `select` operation chooses one value based on a binary condition
@@ -1930,7 +1927,8 @@
     The operation applies to vectors and tensors elementwise given the _shape_
     of all operands is identical. The choice is made for each element
     individually based on the value at the same position as the element in the
-    condition operand.
+    condition operand. If an i1 is provided as the condition, the entire vector
+    or tensor is chosen.
 
     The `select` operation combined with [`cmpi`](#stdcmpi-cmpiop) can be used
     to implement `min` and `max` with signed or unsigned comparison semantics.
@@ -1944,9 +1942,11 @@
     // Generic form of the same operation.
     %x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32
 
-    // Vector selection is element-wise
-    %vx = "std.select"(%vcond, %vtrue, %vfalse)
-        : (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
+    // Element-wise vector selection.
+    %vx = std.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32>
+
+    // Full vector selection.
+    %vx = std.select %cond, %vtrue, %vfalse : vector<42xf32>
     ```
   }];
 
@@ -1954,7 +1954,6 @@
                        AnyType:$true_value,
                        AnyType:$false_value);
   let results = (outs AnyType:$result);
-  let verifier = ?;
 
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value condition,"
@@ -1970,10 +1969,6 @@
   }];
 
   let hasFolder = 1;
-
-  let assemblyFormat = [{
-    $condition `,` $true_value `,` $false_value attr-dict `:` type($result)
-  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -999,15 +999,6 @@
     if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
       return failure();
 
-    // TODO: ATM Tensor/Vector SelectOp requires that the condition has the same
-    // shape as the operands. We should relax that to allow an i1 to signify
-    // that everything is selected.
-    auto doesntSupportsScalarI1 = [](Type type) {
-      return type.isa<TensorType>() || type.isa<VectorType>();
-    };
-    if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1))
-      return failure();
-
     // Generate a select for any operands that differ between the two.
     SmallVector<Value, 8> mergedOperands;
     mergedOperands.reserve(trueOperands.size());
@@ -1925,6 +1916,59 @@
   return nullptr;
 }
 
+static void print(OpAsmPrinter &p, SelectOp op) {
+  p << "select " << op.getOperands();
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : ";
+  if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
+    p << condType << ", ";
+  p << op.getType();
+}
+
+static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
+  Type conditionType, resultType;
+  SmallVector<OpAsmParser::OperandType, 3> operands;
+  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(resultType))
+    return failure();
+
+  // Check for the explicit condition type if this is a masked tensor or vector.
+  if (succeeded(parser.parseOptionalComma())) {
+    conditionType = resultType;
+    if (parser.parseType(resultType))
+      return failure();
+  } else {
+    conditionType = parser.getBuilder().getI1Type();
+  }
+
+  result.addTypes(resultType);
+  return parser.resolveOperands(operands,
+                                {conditionType, resultType, resultType},
+                                parser.getNameLoc(), result.operands);
+}
+
+static LogicalResult verify(SelectOp op) {
+  Type conditionType = op.getCondition().getType();
+  if (conditionType.isSignlessInteger(1))
+    return success();
+
+  // If the result type is a vector or tensor, the type can be a mask with the
+  // same elements.
+  Type resultType = op.getType();
+  if (!resultType.isa<TensorType>() && !resultType.isa<VectorType>())
+    return op.emitOpError()
+           << "expected condition to be a signless i1, but got "
+           << conditionType;
+  Type shapedConditionType = getI1SameShape(resultType);
+  if (conditionType != shapedConditionType)
+    return op.emitOpError()
+           << "expected condition type to have the same shape "
+              "as the result type, expected "
+           << shapedConditionType << ", but got " << conditionType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SignExtendIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -69,39 +69,18 @@
 
 // CHECK-LABEL: func @cond_br_same_successor_insert_select(
 // CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
-func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 {
+// CHECK-SAME: %[[ARG2:.*]]: tensor<2xi32>, %[[ARG3:.*]]: tensor<2xi32>
+func @cond_br_same_successor_insert_select(
+      %cond : i1, %a : i32, %b : i32, %c : tensor<2xi32>, %d : tensor<2xi32>
+    ) -> (i32, tensor<2xi32>) {
   // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
-  // CHECK: return %[[RES]]
-
-  cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
-
-^bb1(%result : i32):
-  return %result : i32
-}
-
-/// Check that we don't generate a select if the type requires a splat.
-/// TODO: SelectOp should allow for matching a vector/tensor with i1.
-
-// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor(
-func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>,
-                                              %b : tensor<2xi32>) -> tensor<2xi32>{
-  // CHECK: cond_br
-
-  cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>)
-
-^bb1(%result : tensor<2xi32>):
-  return %result : tensor<2xi32>
-}
-
-// CHECK-LABEL: func @cond_br_same_successor_no_select_vector(
-func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>,
-                                              %b : vector<2xi32>) -> vector<2xi32> {
-  // CHECK: cond_br
+  // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG2]], %[[ARG3]]
+  // CHECK: return %[[RES]], %[[RES2]]
 
-  cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>)
+  cond_br %cond, ^bb1(%a, %c : i32, tensor<2xi32>), ^bb1(%b, %d : i32, tensor<2xi32>)
 
-^bb1(%result : vector<2xi32>):
-  return %result : vector<2xi32>
+^bb1(%result : i32, %result2 : tensor<2xi32>):
+  return %result, %result2 : i32, tensor<2xi32>
 }
 
 /// Test the compound folding of BranchOp and CondBranchOp.
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -141,17 +141,17 @@
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
   %21 = select %18, %idx, %idx : index
 
-  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32>
-  %22 = select %19, %tci32, %tci32 : tensor<42 x i32>
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi1>, tensor<42xi32>
+  %22 = select %19, %tci32, %tci32 : tensor<42 x i1>, tensor<42 x i32>
 
-  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi32>
-  %23 = select %20, %vci32, %vci32 : vector<42 x i32>
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi1>, vector<42xi32>
+  %23 = select %20, %vci32, %vci32 : vector<42 x i1>, vector<42 x i32>
 
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
   %24 = "std.select"(%18, %idx, %idx) : (i1, index, index) -> index
 
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32>
-  %25 = "std.select"(%19, %tci32, %tci32) : (tensor<42 x i1>, tensor<42 x i32>, tensor<42 x i32>) -> tensor<42 x i32>
+  %25 = std.select %18, %tci32, %tci32 : tensor<42 x i32>
 
   // CHECK: %{{[0-9]+}} = divi_signed %arg2, %arg2 : i32
   %26 = divi_signed %i, %i : i32
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -281,18 +281,18 @@
 
 // -----
 
-func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) {
-^bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error@+1 {{requires the same shape for all operands and results}}
-  %r = "std.select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
+func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
+^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
+  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}}
+  %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
 // -----
 
-func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
-^bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
-  // expected-error@+1 {{ op requires the same shape for all operands and results}}
-  %r = "std.select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
+func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
+^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
+  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}}
+  %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 
 // -----