diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -426,23 +426,23 @@ FailureOr createOperationExpr(SMRange loc, const ast::OpNameDecl *name, OpResultTypeContext resultTypeContext, - MutableArrayRef operands, + SmallVectorImpl &operands, MutableArrayRef attributes, - MutableArrayRef results); + SmallVectorImpl &results); LogicalResult validateOperationOperands(SMRange loc, Optional name, const ods::Operation *odsOp, - MutableArrayRef operands); + SmallVectorImpl &operands); LogicalResult validateOperationResults(SMRange loc, Optional name, const ods::Operation *odsOp, - MutableArrayRef results); + SmallVectorImpl &results); void checkOperationResultTypeInferrence(SMRange loc, StringRef name, const ods::Operation *odsOp); LogicalResult validateOperationOperandsOrResults( StringRef groupName, SMRange loc, Optional odsOpLoc, - Optional name, MutableArrayRef values, + Optional name, SmallVectorImpl &values, ArrayRef odsValues, ast::Type singleTy, - ast::Type rangeTy); + ast::RangeType rangeTy); FailureOr createTupleExpr(SMRange loc, ArrayRef elements, ArrayRef elementNames); @@ -2851,9 +2851,9 @@ FailureOr Parser::createOperationExpr( SMRange loc, const ast::OpNameDecl *name, OpResultTypeContext resultTypeContext, - MutableArrayRef operands, + SmallVectorImpl &operands, MutableArrayRef attributes, - MutableArrayRef results) { + SmallVectorImpl &results) { Optional opNameRef = name->getName(); const ods::Operation *odsOp = lookupODSOperation(opNameRef); @@ -2896,7 +2896,7 @@ LogicalResult Parser::validateOperationOperands(SMRange loc, Optional name, const ods::Operation *odsOp, - MutableArrayRef operands) { + SmallVectorImpl &operands) { return validateOperationOperandsOrResults( "operand", loc, odsOp ? odsOp->getLoc() : Optional(), name, operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, @@ -2906,7 +2906,7 @@ LogicalResult Parser::validateOperationResults(SMRange loc, Optional name, const ods::Operation *odsOp, - MutableArrayRef results) { + SmallVectorImpl &results) { return validateOperationOperandsOrResults( "result", loc, odsOp ? odsOp->getLoc() : Optional(), name, results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); @@ -2956,9 +2956,9 @@ LogicalResult Parser::validateOperationOperandsOrResults( StringRef groupName, SMRange loc, Optional odsOpLoc, - Optional name, MutableArrayRef values, + Optional name, SmallVectorImpl &values, ArrayRef odsValues, ast::Type singleTy, - ast::Type rangeTy) { + ast::RangeType rangeTy) { // All operation types accept a single range parameter. if (values.size() == 1) { if (failed(convertExpressionTo(values[0], rangeTy))) @@ -2969,14 +2969,56 @@ /// If the operation has ODS information, we can more accurately verify the /// values. if (odsOpLoc) { - if (odsValues.size() != values.size()) { + auto emitSizeMismatchError = [&] { return emitErrorAndNote( loc, llvm::formatv("invalid number of {0} groups for `{1}`; expected " "{2}, but got {3}", groupName, *name, odsValues.size(), values.size()), *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); + }; + + // Handle the case where no values were provided. + if (values.empty()) { + // If we don't expect any on the ODS side, we are done. + if (odsValues.empty()) + return success(); + + // If we do, check if we actually need to provide values (i.e. if any of + // the values are actually required). + unsigned numVariadic = 0; + for (const auto &odsValue : odsValues) { + if (!odsValue.isVariableLength()) + return emitSizeMismatchError(); + ++numVariadic; + } + + // If we are in a non-rewrite context, we don't need to do anything more. + // Zero-values is a valid constraint on the operation. + if (parserContext != ParserContext::Rewrite) + return success(); + + // Otherwise, when in a rewrite we may need to provide values to match the + // ODS signature of the operation to create. + + // If we only have one variadic value, just use an empty list. + if (numVariadic == 1) + return success(); + + // Otherwise, create dummy values for each of the entries so that we + // adhere to the ODS signature. + for (unsigned i = 0, e = odsValues.size(); i < e; ++i) { + values.push_back( + ast::RangeExpr::create(ctx, loc, /*elements=*/llvm::None, rangeTy)); + } + return success(); } + + // Verify that the number of values provided matches the number of value + // groups ODS expects. + if (odsValues.size() != values.size()) + return emitSizeMismatchError(); + auto diagFn = [&](ast::Diagnostic &diag) { diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), *odsOpLoc); diff --git a/mlir/test/lib/Transforms/TestDialectConversion.pdll b/mlir/test/lib/Transforms/TestDialectConversion.pdll --- a/mlir/test/lib/Transforms/TestDialectConversion.pdll +++ b/mlir/test/lib/Transforms/TestDialectConversion.pdll @@ -10,9 +10,8 @@ #include "mlir/Transforms/DialectConversion.pdll" /// Change the result type of a producer. -// FIXME: We shouldn't need to specify arguments for the result cast. -Pattern => replace op(args: ValueRange) -> (results: TypeRange) - with op(args) -> (convertTypes(results)); +Pattern => replace op -> (results: TypeRange) + with op -> (convertTypes(results)); /// Pass through test.return conversion. Pattern => replace op(args: ValueRange) diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -213,6 +213,34 @@ // ----- +// Test that we don't need to provide values if all elements +// are optional. + +#include "include/ops.td" + +// CHECK: Module +// CHECK: -OperationExpr {{.*}} Type> +// CHECK-NOT: `Operands` +// CHECK-NOT: `Result Types` +// CHECK: -OperationExpr {{.*}} Type> +// CHECK-NOT: `Operands` +// CHECK-NOT: `Result Types` +// CHECK: -OperationExpr {{.*}} Type> +// CHECK: `Operands` +// CHECK: -RangeExpr {{.*}} Type +// CHECK: -RangeExpr {{.*}} Type +// CHECK: `Result Types` +// CHECK: -RangeExpr {{.*}} Type +// CHECK: -RangeExpr {{.*}} Type +Pattern { + rewrite op() -> () with { + op -> (); + op -> (); + }; +} + +// ----- + //===----------------------------------------------------------------------===// // TupleExpr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/include/ops.td b/mlir/test/mlir-pdll/Parser/include/ops.td --- a/mlir/test/mlir-pdll/Parser/include/ops.td +++ b/mlir/test/mlir-pdll/Parser/include/ops.td @@ -28,3 +28,8 @@ def OpMultipleSingleResult : Op { let results = (outs I64:$result, I64:$result2); } + +def OpMultiVariadic : Op { + let arguments = (ins Variadic:$operands, Variadic:$operand2); + let results = (outs Variadic:$results, Variadic:$results2); +}