diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -610,7 +610,7 @@ imposed by one's target platform. }]; let assemblyFormat = [{ - $format attr-dict ($args^ `:` type($args))? + $format attr-dict (`(` $args^ `:` type($args) `)`)? }]; } diff --git a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td --- a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td @@ -186,13 +186,13 @@ Example: ```mlir - %0 = gpu.device_async_create_group + %0 = nvgpu.device_async_create_group(%cp0) ``` }]; let results = (outs NVGPU_DeviceAsyncToken:$asyncToken); let arguments = (ins Variadic:$inputTokens); let assemblyFormat = [{ - $inputTokens attr-dict + `(` $inputTokens `)` attr-dict }]; } diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -472,7 +472,7 @@ Variadic>:$replValues); let assemblyFormat = [{ $operation `with` (`(` $replValues^ `:` type($replValues) `)`)? - ($replOperation^)? attr-dict + (`(` $replOperation^ `)`)?
attr-dict }]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -820,15 +820,15 @@ %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing - %wf = shape.assuming_all %w0, %w1 // Failure - %wt = shape.assuming_all %w0, %w2 // Passing + %wf = shape.assuming_all(%w0, %w1) // Failure + %wt = shape.assuming_all(%w0, %w2) // Passing ``` }]; let arguments = (ins Variadic:$inputs); let results = (outs Shape_WitnessType:$result); - let assemblyFormat = "$inputs attr-dict"; + let assemblyFormat = "`(` $inputs `)` attr-dict"; let hasFolder = 1; let hasCanonicalizer = 1; diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -77,7 +77,7 @@ // CHECK: %[[FIRST:.*]] = operation "test.first" // CHECK: %[[SECOND:.*]] = operation "test.second" // CHECK: rewrite %[[FIRST]] { -// CHECK: replace %[[FIRST]] with %[[SECOND]] +// CHECK: replace %[[FIRST]] with (%[[SECOND]]) Pattern TupleMemberAccessNumber { let firstOp = op; let secondOp = op(firstOp); @@ -91,7 +91,7 @@ // CHECK: %[[FIRST:.*]] = operation "test.first" // CHECK: %[[SECOND:.*]] = operation "test.second" // CHECK: rewrite %[[FIRST]] { -// CHECK: replace %[[FIRST]] with %[[SECOND]] +// CHECK: replace %[[FIRST]] with (%[[SECOND]]) Pattern TupleMemberAccessName { let firstOp = op; let secondOp = op(firstOp); diff --git a/mlir/test/mlir-tblgen/op-format-verify.td b/mlir/test/mlir-tblgen/op-format-verify-attribute-colon.td rename from mlir/test/mlir-tblgen/op-format-verify.td rename to mlir/test/mlir-tblgen/op-format-verify-attribute-colon.td diff --git 
a/mlir/test/mlir-tblgen/op-format-verify-trailing-optional-operand.td b/mlir/test/mlir-tblgen/op-format-verify-trailing-optional-operand.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-format-verify-trailing-optional-operand.td @@ -0,0 +1,25 @@ +// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s -o=%t 2>&1 | FileCheck %s + +include "mlir/IR/OpBase.td" + +def TestDialect : Dialect { + let name = "test"; +} +class TestFormat_Op traits = []> + : Op { + let assemblyFormat = fmt; +} + +//===----------------------------------------------------------------------===// +// Format ambiguity caused by trailing optional operand(s) +//===----------------------------------------------------------------------===// + +// Test trailing optional operand. +// CHECK: error: format ambiguity caused by trailing +def AmbiguousTypeA : TestFormat_Op<"attr-dict $a">, + Arguments<(ins Optional:$a)>; + +// Test trailing optional operand followed by optional elements. +// CHECK: error: format ambiguity caused by trailing +def AmbiguousTypeB : TestFormat_Op<"$a attr-dict">, + Arguments<(ins Optional:$a)>; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2202,6 +2202,11 @@ LogicalResult verifyAttributeColonType(SMLoc loc, ArrayRef elements); + /// Verify that the format does not end with an optional or variadic + /// operand, which can cause ambiguities with the next operation's results. + LogicalResult + verifyTrailingOptionalOperand(SMLoc loc, ArrayRef elements); + /// Verify the state of operation operands within the format. LogicalResult verifyOperands(SMLoc loc, @@ -2338,6 +2343,10 @@ if (failed(verifyAttributeColonType(loc, elements))) return failure(); + // Check for format ambiguities. + if (failed(verifyTrailingOptionalOperand(loc, elements))) + return failure(); + // Check for VariadicOfVariadic variables.
The segment attribute of those // variables will be infered. for (const NamedTypeConstraint *var : seenOperands) { @@ -2472,6 +2481,53 @@ [&](const Twine &msg) { return emitError(loc, msg); }, elements); } +static LogicalResult +verifyTrailingOperands(function_ref emitError, + ArrayRef elements, + FormatElement *anchor = nullptr) { + for (FormatElement *element : llvm::reverse(elements)) { + // The anchor element is always required, so nothing before it trails. + if (element == anchor) + return success(); + + // Operands end the scan; only a variable-length one is ambiguous. + if (auto *operandVar = dyn_cast(element)) { + const NamedTypeConstraint *operand = operandVar->getVar(); + if (!operand->isVariableLength()) + return success(); + return emitError("format ambiguity caused by trailing optional or " + "variadic operand in assembly format"); + } + + // Recurse into optional groups; since they may be absent, keep scanning. + if (auto *optional = dyn_cast(element)) { + if (failed(verifyTrailingOperands( + emitError, optional->getThenElements().drop_front(), + optional->getAnchor())) || + // The guard (first then-element) cannot be a variable-length operand. + failed(verifyTrailingOperands( + emitError, optional->getThenElements().take_front(1))) || + failed( + verifyTrailingOperands(emitError, optional->getElseElements()))) + return failure(); + } else if (!isOptionallyParsed(element)) { + // A non-operand element that is always parsed ends the scan. + return success(); + } + } + // Reached the start without finding a trailing variable-length operand. + return success(); +} + +LogicalResult OpFormatParser::verifyTrailingOptionalOperand( + SMLoc loc, ArrayRef elements) { + // Terminators may end with operands (presumably no op follows them). + if (op.getTrait("::mlir::OpTrait::IsTerminator")) + return success(); + return verifyTrailingOperands( + [&](const Twine &msg) { return emitError(loc, msg); }, elements); +} + LogicalResult OpFormatParser::verifyOperands( SMLoc loc, llvm::StringMap &variableTyResolver) { // Check that all of the operands are within the format, and their types can