diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -302,8 +302,7 @@ return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } - rewriter.template replaceOpWithNewOp(operation, dstType, operands, - ArrayRef()); + rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } }; @@ -326,11 +325,11 @@ if (!dstType) return failure(); if (isBoolScalarOrVector(operands.front().getType())) { - rewriter.template replaceOpWithNewOp( - operation, dstType, operands, ArrayRef()); + rewriter.template replaceOpWithNewOp(operation, dstType, + operands); } else { - rewriter.template replaceOpWithNewOp( - operation, dstType, operands, ArrayRef()); + rewriter.template replaceOpWithNewOp(operation, dstType, + operands); } return success(); } @@ -487,8 +486,8 @@ // Then we can just erase this operation by forwarding its operand.
rewriter.replaceOp(operation, operands.front()); } else { - rewriter.template replaceOpWithNewOp( - operation, dstType, operands, ArrayRef()); + rewriter.template replaceOpWithNewOp(operation, dstType, + operands); } return success(); } @@ -990,8 +989,7 @@ auto dstType = typeConverter.convertType(xorOp.getType()); if (!dstType) return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, operands, - ArrayRef()); + rewriter.replaceOpWithNewOp(xorOp, dstType, operands); return success(); } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -418,8 +418,7 @@ } if (auto *GV = dyn_cast(c)) return bEntry.create(UnknownLoc::get(context), - processGlobal(GV), - ArrayRef()); + processGlobal(GV)); if (auto *ce = dyn_cast(c)) { llvm::Instruction *i = ce->getAsInstruction(); @@ -727,7 +726,7 @@ if (!calledValue) return failure(); ops.insert(ops.begin(), calledValue); - op = b.create(loc, tys, ops, ArrayRef()); + op = b.create(loc, tys, ops); } if (!ci->getType()->isVoidTy()) v = op->getResult(0); @@ -809,7 +808,7 @@ Type type = processType(inst->getType()); if (!type) return failure(); - v = b.create(loc, type, ops, ArrayRef()); + v = b.create(loc, type, ops); return success(); } } diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -171,6 +171,49 @@ // CHECK-LABEL: class GOp : // CHECK: static ::mlir::LogicalResult inferReturnTypes +// Check that the collective params builder gives its trailing `attributes` +// parameter a default value, and that the other builders are still generated.
+def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> { + let arguments = (ins AnyType:$a); + let results = (outs AnyType:$b); +} + +// CHECK-LABEL: class HCollectiveParamsOp : +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + +// Check suppression of build method with single variadic arg and single +// variadic result +def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> { + let arguments = (ins Variadic<AnyType>:$a); + let results = (outs Variadic<AnyType>:$b); +} + +// CHECK-LABEL: class HCollectiveParamsSuppress0Op : +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + +// Check suppression of build method when single variadic arg and non variadic result +def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> { + let arguments = (ins Variadic<AnyType>:$a); + let results = (outs I32:$b); +} + +// CHECK-LABEL: class HCollectiveParamsSuppress1Op : +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + +// Check suppression of build method when single variadic arg and > 1 variadic result +def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVariadicResultSize]> { + let arguments = (ins Variadic<AnyType>:$a); + let results = (outs Variadic<AnyType>:$b, Variadic<AnyType>:$c); +} +// CHECK-LABEL: class HCollectiveParamsSuppress2Op : +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + // Check that default builders can be suppressed. // ---
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -955,14 +955,46 @@ llvm_unreachable("unhandled TypeParamKind"); }; + // A separate arg param builder method has a signature that is ambiguous + // with the collective params build method (generated by + // `genCollectiveParamBuilder` below, which defaults `attributes` to `{}` + // when it is the last parameter) if its only parameters after the builder + // and state are one `ArrayRef<Type>` and one `ValueRange`, i.e. when: + // 1. getNumVariadicRegions() is 0 (otherwise the collective params build + // method ends with a `numRegions` parameter, and `attributes` gets no + // default value). + // 2. getNumArgs() is 1 (otherwise each arg gets its own parameter in the + // build methods generated here) and that single arg is a variadic + // operand rather than an attribute. + // 3. The result types are passed as a single `ArrayRef<Type>`, i.e. + // 3a. paramKind is Collective, or + // 3b. paramKind is Separate and there is a single variadic result. + // + // In that case, skip generating such ambiguous build methods here. + bool hasSingleVariadicResult = + op.getNumResults() == 1 && op.getResult(0).isVariadic(); + + bool hasSingleVariadicArg = + op.getNumArgs() == 1 && + op.getNumOperands() == 1 && + op.getOperand(0).isVariadic(); + bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0; + for (auto attrType : attrBuilderType) { - emit(attrType, TypeParamKind::Separate, /*inferType=*/false); + // Skip when conditions 1, 2 and 3b above all hold. + if (!(hasNoVariadicRegions && hasSingleVariadicArg && + hasSingleVariadicResult)) + emit(attrType, TypeParamKind::Separate, /*inferType=*/false); if (canInferType(op)) emit(attrType, TypeParamKind::None, /*inferType=*/true); - // Emit separate arg build with collective type, unless there is only one - // variadic result, in which case the above would have already generated - // the same build method. - if (!(op.getNumResults() == 1 && op.getResult(0).isVariableLength())) + // The separate arg + collective param kind method would be: + // (a) identical to the separate arg + separate param kind method above + // when there is a single variadic result, or + // (b) ambiguous with the collective params method when conditions 1, 2 + // and 3a above hold. + // In either case, skip generating it. + if (!hasSingleVariadicResult && + !(hasNoVariadicRegions && hasSingleVariadicArg)) emit(attrType, TypeParamKind::Collective, /*inferType=*/false); } }
@@ -1186,6 +1218,10 @@ "::llvm::ArrayRef<::mlir::NamedAttribute> attributes"; - if (op.getNumVariadicRegions()) - params += ", unsigned numRegions"; + if (op.getNumVariadicRegions()) { + params += ", unsigned numRegions"; + } else { + // `attributes` is the last parameter here, so give it a default value. + params += " = {}"; + } auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body();