diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -71,7 +71,7 @@ /// A generalized printer for binary operations. It prints in two different /// forms depending on if all of the types match. static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << op->getName() << " " << op->getOperands(); + printer << " " << op->getOperands(); printer.printOptionalAttrDict(op->getAttrs()); printer << " : "; @@ -121,7 +121,7 @@ /// The 'OpAsmPrinter' class is a stream that allows for formatting /// strings, attributes, operands, types, etc. static void print(mlir::OpAsmPrinter &printer, ConstantOp op) { - printer << "toy.constant "; + printer << " "; printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); printer << op.value(); } diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -71,7 +71,7 @@ /// A generalized printer for binary operations. It prints in two different /// forms depending on if all of the types match. static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << op->getName() << " " << op->getOperands(); + printer << " " << op->getOperands(); printer.printOptionalAttrDict(op->getAttrs()); printer << " : "; @@ -121,7 +121,7 @@ /// The 'OpAsmPrinter' class is a stream that allows for formatting /// strings, attributes, operands, types, etc. 
static void print(mlir::OpAsmPrinter &printer, ConstantOp op) { - printer << "toy.constant "; + printer << " "; printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); printer << op.value(); } diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -127,7 +127,7 @@ /// A generalized printer for binary operations. It prints in two different /// forms depending on if all of the types match. static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << op->getName() << " " << op->getOperands(); + printer << " " << op->getOperands(); printer.printOptionalAttrDict(op->getAttrs()); printer << " : "; @@ -177,7 +177,7 @@ /// The 'OpAsmPrinter' class is a stream that allows for formatting /// strings, attributes, operands, types, etc. static void print(mlir::OpAsmPrinter &printer, ConstantOp op) { - printer << "toy.constant "; + printer << " "; printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); printer << op.value(); } diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -127,7 +127,7 @@ /// A generalized printer for binary operations. It prints in two different /// forms depending on if all of the types match. static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << op->getName() << " " << op->getOperands(); + printer << " " << op->getOperands(); printer.printOptionalAttrDict(op->getAttrs()); printer << " : "; @@ -177,7 +177,7 @@ /// The 'OpAsmPrinter' class is a stream that allows for formatting /// strings, attributes, operands, types, etc. 
static void print(mlir::OpAsmPrinter &printer, ConstantOp op) { - printer << "toy.constant "; + printer << " "; printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); printer << op.value(); } diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -127,7 +127,7 @@ /// A generalized printer for binary operations. It prints in two different /// forms depending on if all of the types match. static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << op->getName() << " " << op->getOperands(); + printer << " " << op->getOperands(); printer.printOptionalAttrDict(op->getAttrs()); printer << " : "; @@ -177,7 +177,7 @@ /// The 'OpAsmPrinter' class is a stream that allows for formatting /// strings, attributes, operands, types, etc. static void print(mlir::OpAsmPrinter &printer, ConstantOp op) { - printer << "toy.constant "; + printer << " "; printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); printer << op.value(); } diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -114,7 +114,7 @@ /// A generalized printer for binary operations. It prints in two different /// forms depending on if all of the types match. static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << op->getName() << " " << op->getOperands(); + printer << " " << op->getOperands(); printer.printOptionalAttrDict(op->getAttrs()); printer << " : "; @@ -164,7 +164,7 @@ /// The 'OpAsmPrinter' class is a stream that allows for formatting /// strings, attributes, operands, types, etc. 
static void print(mlir::OpAsmPrinter &printer, ConstantOp op) { - printer << "toy.constant "; + printer << " "; printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); printer << op.value(); } diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -537,8 +537,7 @@ let builders = [OpBuilder<(ins), [{ // empty}]>]; - let parser = [{ return parseReturnOp(parser, result); }]; - let printer = [{ p << getOperationName(); }]; + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; let verifier = [{ return ::verify(*this); }]; } @@ -552,8 +551,7 @@ terminator takes no operands. }]; - let parser = [{ return success(); }]; - let printer = [{ p << getOperationName(); }]; + let assemblyFormat = "attr-dict"; } def GPU_YieldOp : GPU_Op<"yield", [NoSideEffect, Terminator]>, @@ -681,8 +679,7 @@ Either none or all work items of a workgroup need to execute this op in convergence. }]; - let parser = [{ return success(); }]; - let printer = [{ p << getOperationName(); }]; + let assemblyFormat = "attr-dict"; } def GPU_GPUModuleOp : GPU_Op<"module", [ @@ -733,8 +730,7 @@ This op terminates the only block inside the only region of a `gpu.module`. 
}]; - let parser = [{ return success(); }]; - let printer = [{ p << getOperationName(); }]; + let assemblyFormat = "attr-dict"; } def GPU_HostRegisterOp : GPU_Op<"host_register">, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -709,8 +709,7 @@ } def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { string llvmBuilder = [{ builder.CreateUnreachable(); }]; - let parser = [{ return success(); }]; - let printer = [{ p << getOperationName(); }]; + let assemblyFormat = "attr-dict"; } def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -162,7 +162,7 @@ let parser = [{ return parseROCDLMubufLoadOp(parser, result); }]; let printer = [{ Operation *op = this->getOperation(); - p << op->getName() << " " << op->getOperands() + p << " " << op->getOperands() << " : " << op->getResultTypes(); }]; } @@ -184,7 +184,7 @@ let parser = [{ return parseROCDLMubufStoreOp(parser, result); }]; let printer = [{ Operation *op = this->getOperation(); - p << op->getName() << " " << op->getOperands() + p << " " << op->getOperands() << " : " << vdata().getType(); }]; } diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -82,7 +82,7 @@ /// linalg::(Tensor)CollapseShapeOp. 
template void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) { - p << op.getOperationName() << ' ' << op.src() << " ["; + p << ' ' << op.src() << " ["; llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) { p << '['; @@ -223,8 +223,8 @@ ShapedType resultType = reshapeOp.getResultType(); Optional> reassociationIndices = composeReassociationIndices(srcReshapeOp.getReassociationIndices(), - reshapeOp.getReassociationIndices(), - rewriter.getContext()); + reshapeOp.getReassociationIndices(), + rewriter.getContext()); if (!reassociationIndices) return failure(); rewriter.replaceOpWithNewOp( diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -180,6 +180,7 @@ // The fallback for the printer is to print it the generic assembly form. static void print(Operation *op, OpAsmPrinter &p); + static void printOpName(Operation *op, OpAsmPrinter &p); /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, /// so we can cast it away here. @@ -1776,8 +1777,8 @@ static std::enable_if_t::value, AbstractOperation::PrintAssemblyFn> getPrintAssemblyFnImpl() { - return [](Operation *op, OpAsmPrinter &parser) { - return OpState::print(op, parser); + return [](Operation *op, OpAsmPrinter &printer) { + return OpState::print(op, printer); }; } /// The internal implementation of `getPrintAssemblyFn` that is invoked when @@ -1789,6 +1790,7 @@ return &printAssembly; } static void printAssembly(Operation *op, OpAsmPrinter &p) { + OpState::printOpName(op, p); return cast(op).print(p); } /// Implementation of `VerifyInvariantsFn` AbstractOperation hook. 
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -98,7 +98,9 @@ ArrayRef elidedAttrs = {}) = 0; /// Print the entire operation with the default generic assembly form. - virtual void printGenericOp(Operation *op) = 0; + /// If `printOpName` is true, then the operation name is printed (the default), + /// otherwise it is omitted and the print will start with the operand list. + virtual void printGenericOp(Operation *op, bool printOpName = true) = 0; /// Prints a region. /// If 'printEntryBlockArgs' is false, the arguments of the diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -516,7 +516,7 @@ } static void print(OpAsmPrinter &p, AffineApplyOp op) { - p << AffineApplyOp::getOperationName() << " " << op.mapAttr(); + p << " " << op.mapAttr(); printDimAndSymbolList(op.operand_begin(), op.operand_end(), op.getAffineMap().getNumDims(), p); p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"}); @@ -997,7 +997,7 @@ } void AffineDmaStartOp::print(OpAsmPrinter &p) { - p << "affine.dma_start " << getSrcMemRef() << '['; + p << " " << getSrcMemRef() << '['; p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); p << "], " << getDstMemRef() << '['; p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); @@ -1150,7 +1150,7 @@ } void AffineDmaWaitOp::print(OpAsmPrinter &p) { - p << "affine.dma_wait " << getTagMemRef() << '['; + p << " " << getTagMemRef() << '['; SmallVector operands(getTagIndices()); p.printAffineMapOfSSAIds(getTagMapAttr(), operands); p << "], "; @@ -1527,7 +1527,7 @@ } static void print(OpAsmPrinter &p, AffineForOp op) { - p << op.getOperationName() << ' '; + p << ' '; p.printOperand(op.getBody()->getArgument(0)); p << " = "; printBound(op.getLowerBoundMapAttr(), 
op.getLowerBoundOperands(), "max", p); @@ -2103,7 +2103,7 @@ static void print(OpAsmPrinter &p, AffineIfOp op) { auto conditionAttr = op->getAttrOfType(op.getConditionAttrName()); - p << "affine.if " << conditionAttr; + p << " " << conditionAttr; printDimAndSymbolList(op.operand_begin(), op.operand_end(), conditionAttr.getValue().getNumDims(), p); p.printOptionalArrowTypeList(op.getResultTypes()); @@ -2248,7 +2248,7 @@ } static void print(OpAsmPrinter &p, AffineLoadOp op) { - p << "affine.load " << op.getMemRef() << '['; + p << " " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); @@ -2364,7 +2364,7 @@ } static void print(OpAsmPrinter &p, AffineStoreOp op) { - p << "affine.store " << op.getValueToStore(); + p << " " << op.getValueToStore(); p << ", " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName())) @@ -2418,7 +2418,7 @@ template static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { - p << op.getOperationName() << ' ' << op->getAttr(T::getMapAttrName()); + p << ' ' << op->getAttr(T::getMapAttrName()); auto operands = op.getOperands(); unsigned numDims = op.map().getNumDims(); p << '(' << operands.take_front(numDims) << ')'; @@ -2713,7 +2713,7 @@ } static void print(OpAsmPrinter &p, AffinePrefetchOp op) { - p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '['; + p << " " << op.memref() << '['; AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName()); if (mapAttr) { SmallVector operands(op.getMapOperands()); @@ -3101,7 +3101,7 @@ } static void print(OpAsmPrinter &p, AffineParallelOp op) { - p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("; + p << " (" << op.getBody()->getArguments() << ") = ("; printMinMaxBound(p, op.lowerBoundsMapAttr(), op.lowerBoundsGroupsAttr(), op.getLowerBoundsOperands(), "max"); p << ") to ("; @@ -3460,7 +3460,7 @@ } static void print(OpAsmPrinter 
&p, AffineVectorLoadOp op) { - p << "affine.vector_load " << op.getMemRef() << '['; + p << " " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); @@ -3552,7 +3552,7 @@ } static void print(OpAsmPrinter &p, AffineVectorStoreOp op) { - p << "affine.vector_store " << op.getValueToStore(); + p << " " << op.getValueToStore(); p << ", " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName())) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -123,7 +123,7 @@ static void print(OpAsmPrinter &p, IncludeOp &op) { bool standardInclude = op.is_standard_include(); - p << IncludeOp::getOperationName() << " "; + p << " "; if (standardInclude) p << "<"; p << "\"" << op.include() << "\""; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -299,8 +299,8 @@ } static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) { - p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' ' - << op.mode() << " : " << op.value().getType(); + p << ' ' << op.getOperands() << ' ' << op.mode() << " : " + << op.value().getType(); } static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) { @@ -441,7 +441,7 @@ static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) { // Print the launch configuration. - p << LaunchOp::getOperationName() << ' ' << op.getBlocksKeyword(); + p << ' ' << op.getBlocksKeyword(); printSizeAssignment(p, op.getGridSize(), op.getGridSizeOperandValues(), op.getBlockIds()); p << ' ' << op.getThreadsKeyword(); @@ -781,7 +781,7 @@ /// Prints a GPU Func op. 
static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) { - p << GPUFuncOp::getOperationName() << ' '; + p << ' '; p.printSymbolName(op.getName()); FunctionType type = op.getType(); @@ -862,18 +862,6 @@ // ReturnOp //===----------------------------------------------------------------------===// -static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { - llvm::SmallVector operands; - llvm::SmallVector types; - if (parser.parseOperandList(operands) || - parser.parseOptionalColonTypeList(types) || - parser.resolveOperands(operands, types, parser.getCurrentLocation(), - result.operands)) - return failure(); - - return success(); -} - static LogicalResult verify(gpu::ReturnOp returnOp) { GPUFuncOp function = returnOp->getParentOfType(); @@ -930,7 +918,7 @@ } static void print(OpAsmPrinter &p, GPUModuleOp op) { - p << op.getOperationName() << ' '; + p << ' '; p.printSymbolName(op.getName()); p.printOptionalAttrDictWithKeyword(op->getAttrs(), {SymbolTable::getSymbolAttrName()}); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -73,15 +73,15 @@ // Printing/parsing for LLVM::CmpOp. 
//===----------------------------------------------------------------------===// static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { - p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) - << "\" " << op.getOperand(0) << ", " << op.getOperand(1); + p << " \"" << stringifyICmpPredicate(op.predicate()) << "\" " + << op.getOperand(0) << ", " << op.getOperand(1); p.printOptionalAttrDict(op->getAttrs(), {"predicate"}); p << " : " << op.lhs().getType(); } static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { - p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) - << "\" " << op.getOperand(0) << ", " << op.getOperand(1); + p << " \"" << stringifyFCmpPredicate(op.predicate()) << "\" " + << op.getOperand(0) << ", " << op.getOperand(1); p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"}); p << " : " << op.lhs().getType(); } @@ -161,7 +161,7 @@ auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()}, {op.getType()}); - p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy; + p << ' ' << op.arraySize() << " x " << elemTy; if (op.alignment().hasValue() && *op.alignment() != 0) p.printOptionalAttrDict(op->getAttrs()); else @@ -421,7 +421,7 @@ } static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { - p << op.getOperationName() << ' '; + p << ' '; if (op.volatile_()) p << "volatile "; p << op.addr(); @@ -483,7 +483,7 @@ } static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { - p << op.getOperationName() << ' '; + p << ' '; if (op.volatile_()) p << "volatile "; p << op.value() << ", " << op.addr(); @@ -549,7 +549,7 @@ auto callee = op.callee(); bool isDirect = callee.hasValue(); - p << op.getOperationName() << ' '; + p << ' '; // Either function name or pointer if (isDirect) @@ -710,7 +710,7 @@ } static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { - p << op.getOperationName() << (op.cleanup() ? " cleanup " : " "); + p << (op.cleanup() ? 
" cleanup " : " "); // Clauses for (auto value : op.getOperands()) { @@ -848,7 +848,7 @@ // Print the direct callee if present as a function attribute, or an indirect // callee (first operand) otherwise. - p << op.getOperationName() << ' '; + p << ' '; if (isDirect) p.printSymbolName(callee.getValue()); else @@ -960,8 +960,8 @@ } static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { - p << op.getOperationName() << ' ' << op.vector() << "[" << op.position() - << " : " << op.position().getType() << "]"; + p << ' ' << op.vector() << "[" << op.position() << " : " + << op.position().getType() << "]"; p.printOptionalAttrDict(op->getAttrs()); p << " : " << op.vector().getType(); } @@ -1007,7 +1007,7 @@ //===----------------------------------------------------------------------===// static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { - p << op.getOperationName() << ' ' << op.container() << op.position(); + p << ' ' << op.container() << op.position(); p.printOptionalAttrDict(op->getAttrs(), {"position"}); p << " : " << op.container().getType(); } @@ -1159,8 +1159,8 @@ //===----------------------------------------------------------------------===// static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { - p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "[" - << op.position() << " : " << op.position().getType() << "]"; + p << ' ' << op.value() << ", " << op.vector() << "[" << op.position() << " : " + << op.position().getType() << "]"; p.printOptionalAttrDict(op->getAttrs()); p << " : " << op.vector().getType(); } @@ -1209,8 +1209,7 @@ //===----------------------------------------------------------------------===// static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { - p << op.getOperationName() << ' ' << op.value() << ", " << op.container() - << op.position(); + p << ' ' << op.value() << ", " << op.container() << op.position(); p.printOptionalAttrDict(op->getAttrs(), 
{"position"}); p << " : " << op.container().getType(); } @@ -1266,7 +1265,6 @@ //===----------------------------------------------------------------------===// static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { - p << op.getOperationName(); p.printOptionalAttrDict(op->getAttrs()); assert(op.getNumOperands() <= 1); @@ -1409,7 +1407,7 @@ } static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { - p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' '; + p << ' ' << stringifyLinkage(op.linkage()) << ' '; if (op.unnamed_addr()) p << stringifyUnnamedAddr(*op.unnamed_addr()) << ' '; if (op.constant()) @@ -1634,8 +1632,7 @@ } static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { - p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " " - << op.mask(); + p << ' ' << op.v1() << ", " << op.v2() << " " << op.mask(); p.printOptionalAttrDict(op->getAttrs(), {"mask"}); p << " : " << op.v1().getType() << ", " << op.v2().getType(); } @@ -1798,7 +1795,7 @@ // helper functions. Drops "void" result since it cannot be parsed back. Skips // the external linkage since it is the default value. 
static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { - p << op.getOperationName() << ' '; + p << ' '; if (op.linkage() != LLVM::Linkage::External) p << stringifyLinkage(op.linkage()) << ' '; p.printSymbolName(op.getName()); @@ -2000,9 +1997,8 @@ //===----------------------------------------------------------------------===// static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { - p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' ' - << op.ptr() << ", " << op.val() << ' ' - << stringifyAtomicOrdering(op.ordering()) << ' '; + p << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' ' << op.ptr() << ", " + << op.val() << ' ' << stringifyAtomicOrdering(op.ordering()) << ' '; p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"}); p << " : " << op.res().getType(); } @@ -2071,8 +2067,8 @@ //===----------------------------------------------------------------------===// static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { - p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", " - << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' ' + p << ' ' << op.ptr() << ", " << op.cmp() << ", " << op.val() << ' ' + << stringifyAtomicOrdering(op.success_ordering()) << ' ' << stringifyAtomicOrdering(op.failure_ordering()); p.printOptionalAttrDict(op->getAttrs(), {"success_ordering", "failure_ordering"}); @@ -2158,7 +2154,7 @@ static void printFenceOp(OpAsmPrinter &p, FenceOp &op) { StringRef syncscopeKeyword = "syncscope"; - p << op.getOperationName() << ' '; + p << ' '; if (!op->getAttr(syncscopeKeyword).cast().getValue().empty()) p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") "; p << stringifyAtomicOrdering(op.ordering()); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -37,7 +37,7 @@ 
//===----------------------------------------------------------------------===// static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { - p << op->getName() << " " << op->getOperands(); + p << " " << op->getOperands(); if (op->getNumResults() > 0) p << " : " << op->getResultTypes(); } @@ -285,7 +285,6 @@ static void printWMMAMmaF16F16M16N16K16Op(OpAsmPrinter &p, WMMAMmaF16F16M16N16K16Op &op) { - p << op.getOperationName(); p << ' '; p << op.args(); p.printOptionalAttrDict(op->getAttrs(), {}); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -559,7 +559,7 @@ } static void print(OpAsmPrinter &p, GenericOp op) { - p << op.getOperationName() << " "; + p << " "; // Print extra attributes. auto genericAttrNames = op.linalgTraitAttrNames(); @@ -1642,7 +1642,6 @@ //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, linalg::YieldOp op) { - p << op.getOperationName(); if (op.getNumOperands() > 0) p << ' ' << op.getOperands(); p.printOptionalAttrDict(op->getAttrs()); @@ -1797,9 +1796,8 @@ } static void print(OpAsmPrinter &p, TiledLoopOp op) { - p << op.getOperationName() << " (" << op.getInductionVars() << ") = (" - << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step() - << ")"; + p << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to (" + << op.upperBound() << ") step (" << op.step() << ")"; if (!op.inputs().empty()) { p << " ins ("; @@ -2092,8 +2090,8 @@ auto src = dimOp.source().template dyn_cast(); if (!src) return failure(); - auto loopOp = dyn_cast( - src.getOwner()->getParent()->getParentOp()); + auto loopOp = + dyn_cast(src.getOwner()->getParent()->getParentOp()); if (!loopOp) return failure(); @@ -2650,7 +2648,6 @@ template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { - p << 
op.getOperationName(); p.printOptionalAttrDict( op->getAttrs(), /*elidedAttrs=*/{"operand_segment_sizes", diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -208,7 +208,7 @@ static void print(OpAsmPrinter &p, AllocaScopeOp &op) { bool printBlockTerminators = false; - p << AllocaScopeOp::getOperationName() << " "; + p << " "; if (!op.results().empty()) { p << " -> (" << op.getResultTypes() << ")"; printBlockTerminators = true; @@ -805,10 +805,9 @@ } void DmaStartOp::print(OpAsmPrinter &p) { - p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices() - << "], " << getDstMemRef() << '[' << getDstIndices() << "], " - << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() - << ']'; + p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " + << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() + << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; if (isStrided()) p << ", " << getStride() << ", " << getNumElementsPerStride(); @@ -970,8 +969,8 @@ } void DmaWaitOp::print(OpAsmPrinter &p) { - p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices() - << "], " << getNumElements(); + p << " " << getTagMemRef() << '[' << getTagIndices() << "], " + << getNumElements(); p.printOptionalAttrDict((*this)->getAttrs()); p << " : " << getTagMemRef().getType(); } @@ -1177,7 +1176,7 @@ //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, PrefetchOp op) { - p << PrefetchOp::getOperationName() << " " << op.memref() << '['; + p << " " << op.memref() << '['; p.printOperands(op.indices()); p << ']' << ", " << (op.isWrite() ? 
"write" : "read"); p << ", locality<" << op.localityHint(); @@ -2242,7 +2241,7 @@ // transpose $in $permutation attr-dict : type($in) `to` type(results) static void print(OpAsmPrinter &p, TransposeOp op) { - p << "memref.transpose " << op.in() << " " << op.permutation(); + p << " " << op.in() << " " << op.permutation(); p.printOptionalAttrDict(op->getAttrs(), {TransposeOp::getPermutationAttrName()}); p << " : " << op.in().getType() << " to " << op.getType(); @@ -2319,7 +2318,7 @@ } static void print(OpAsmPrinter &p, ViewOp op) { - p << op.getOperationName() << ' ' << op.getOperand(0) << '['; + p << ' ' << op.getOperand(0) << '['; p.printOperand(op.byte_shift()); p << "][" << op.sizes() << ']'; p.printOptionalAttrDict(op->getAttrs()); diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -382,8 +382,6 @@ } static void print(OpAsmPrinter &printer, ParallelOp &op) { - printer << ParallelOp::getOperationName(); - // async()? 
if (Value async = op.async()) printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " @@ -599,8 +597,6 @@ } static void print(OpAsmPrinter &printer, LoopOp &op) { - printer << LoopOp::getOperationName(); - unsigned execMapping = op.exec_mapping(); if (execMapping & OpenACCExecMapping::GANG) { printer << " " << LoopOp::getGangKeyword(); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -142,8 +142,6 @@ } static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { - p << "omp.parallel"; - if (auto ifCond = op.if_expr_var()) p << " if(" << ifCond << " : " << ifCond.getType() << ")"; @@ -703,9 +701,8 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { auto args = op.getRegion().front().getArguments(); - p << op.getOperationName() << " (" << args << ") : " << args[0].getType() - << " = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" - << op.step() << ")"; + p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() + << ") to (" << op.upperBound() << ") step (" << op.step() << ")"; // Print private, firstprivate, shared and copyin parameters auto printDataVars = [&p](StringRef name, OperandRange vars) { diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -114,7 +114,6 @@ } static void print(OpAsmPrinter &p, ExecuteRegionOp op) { - p << ExecuteRegionOp::getOperationName(); p.printOptionalArrowTypeList(op.getResultTypes()); p.printRegion(op.region(), @@ -339,8 +338,8 @@ } static void print(OpAsmPrinter &p, ForOp op) { - p << op.getOperationName() << " " << op.getInductionVar() << " = " - << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); + p << " " << op.getInductionVar() << " = " << op.lowerBound() << " to " + << op.upperBound() << " 
step " << op.step(); printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(), " iter_args"); @@ -1099,7 +1098,7 @@ static void print(OpAsmPrinter &p, IfOp op) { bool printBlockTerminators = false; - p << IfOp::getOperationName() << " " << op.condition(); + p << " " << op.condition(); if (!op.results().empty()) { p << " -> (" << op.getResultTypes() << ")"; // Print yield explicitly if the op defines values. @@ -1762,9 +1761,8 @@ } static void print(OpAsmPrinter &p, ParallelOp op) { - p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (" - << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step() - << ")"; + p << " (" << op.getBody()->getArguments() << ") = (" << op.lowerBound() + << ") to (" << op.upperBound() << ") step (" << op.step() << ")"; if (!op.initVals().empty()) p << " init (" << op.initVals() << ")"; p.printOptionalArrowTypeList(op.getResultTypes()); @@ -2020,7 +2018,7 @@ } static void print(OpAsmPrinter &p, ReduceOp op) { - p << op.getOperationName() << "(" << op.operand() << ") "; + p << "(" << op.operand() << ") "; p << " : " << op.operand().getType(); p.printRegion(op.reductionOperator()); } @@ -2124,7 +2122,6 @@ /// Prints a `while` op. static void print(OpAsmPrinter &p, scf::WhileOp op) { - p << op.getOperationName(); printInitializationList(p, op.before().front().getArguments(), op.inits(), " "); p << " : "; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -688,7 +688,7 @@ // Prints an atomic update op. 
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName() << " \""; + printer << " \""; auto scopeAttr = op->getAttrOfType(kMemoryScopeAttrName); printer << spirv::stringifyScope( static_cast(scopeAttr.getInt())) @@ -757,7 +757,7 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer) { - printer << groupOp->getName() << " \"" + printer << " \"" << stringifyScope(static_cast( groupOp->getAttrOfType(kExecutionScopeAttrName) .getInt())) @@ -813,7 +813,7 @@ } static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) { - printer << unaryOp->getName() << ' ' << unaryOp->getOperand(0) << " : " + printer << ' ' << unaryOp->getOperand(0) << " : " << unaryOp->getOperand(0).getType(); } @@ -851,7 +851,7 @@ } static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { - printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : " + printer << ' ' << logicalOp->getOperands() << " : " << logicalOp->getOperand(0).getType(); } @@ -875,8 +875,8 @@ static void printShiftOp(Operation *op, OpAsmPrinter &printer) { Value base = op->getOperand(0); Value shift = op->getOperand(1); - printer << op->getName() << ' ' << base << ", " << shift << " : " - << base.getType() << ", " << shift.getType(); + printer << ' ' << base << ", " << shift << " : " << base.getType() << ", " + << shift.getType(); } static LogicalResult verifyShiftOp(Operation *op) { @@ -1021,7 +1021,7 @@ template static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { - printer << Op::getOperationName() << ' ' << op.base_ptr() << '[' << indices + printer << ' ' << op.base_ptr() << '[' << indices << "] : " << op.base_ptr().getType() << ", " << indices.getTypes(); } @@ -1114,8 +1114,7 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp, OpAsmPrinter &printer) { - printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \"" - << stringifyScope(atomOp.memory_scope()) 
<< "\" \"" + printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \"" << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \"" << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" " << atomOp.getOperands() << " : " << atomOp.pointer().getType(); @@ -1257,8 +1256,7 @@ } static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) { - printer << spirv::BranchConditionalOp::getOperationName() << ' ' - << branchOp.condition(); + printer << ' ' << branchOp.condition(); if (auto weights = branchOp.branch_weights()) { printer << " ["; @@ -1330,8 +1328,7 @@ static void print(spirv::CompositeConstructOp compositeConstructOp, OpAsmPrinter &printer) { - printer << spirv::CompositeConstructOp::getOperationName() << " " - << compositeConstructOp.constituents() << " : " + printer << " " << compositeConstructOp.constituents() << " : " << compositeConstructOp.getResult().getType(); } @@ -1405,9 +1402,9 @@ static void print(spirv::CompositeExtractOp compositeExtractOp, OpAsmPrinter &printer) { - printer << spirv::CompositeExtractOp::getOperationName() << ' ' - << compositeExtractOp.composite() << compositeExtractOp.indices() - << " : " << compositeExtractOp.composite().getType(); + printer << ' ' << compositeExtractOp.composite() + << compositeExtractOp.indices() << " : " + << compositeExtractOp.composite().getType(); } static LogicalResult verify(spirv::CompositeExtractOp compExOp) { @@ -1479,10 +1476,9 @@ static void print(spirv::CompositeInsertOp compositeInsertOp, OpAsmPrinter &printer) { - printer << spirv::CompositeInsertOp::getOperationName() << " " - << compositeInsertOp.object() << ", " << compositeInsertOp.composite() - << compositeInsertOp.indices() << " : " - << compositeInsertOp.object().getType() << " into " + printer << " " << compositeInsertOp.object() << ", " + << compositeInsertOp.composite() << compositeInsertOp.indices() + << " : " << compositeInsertOp.object().getType() << " into " << 
compositeInsertOp.composite().getType(); } @@ -1505,7 +1501,7 @@ } static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) { - printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value(); + printer << ' ' << constOp.value(); if (constOp.getType().isa()) printer << " : " << constOp.getType(); } @@ -1748,8 +1744,8 @@ } static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) { - printer << spirv::EntryPointOp::getOperationName() << " \"" - << stringifyExecutionModel(entryPointOp.execution_model()) << "\" "; + printer << " \"" << stringifyExecutionModel(entryPointOp.execution_model()) + << "\" "; printer.printSymbolName(entryPointOp.fn()); auto interfaceVars = entryPointOp.interface().getValue(); if (!interfaceVars.empty()) { @@ -1802,7 +1798,7 @@ } static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) { - printer << spirv::ExecutionModeOp::getOperationName() << " "; + printer << " "; printer.printSymbolName(execModeOp.fn()); printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode()) << "\""; @@ -1868,7 +1864,7 @@ static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) { // Print function name, signature, and control. - printer << spirv::FuncOp::getOperationName() << " "; + printer << " "; printer.printSymbolName(fnOp.sym_name()); auto fnType = fnOp.getType(); function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(), @@ -2080,7 +2076,6 @@ auto *op = varOp.getOperation(); SmallVector elidedAttrs{ spirv::attributeName()}; - printer << spirv::GlobalVariableOp::getOperationName(); // Print variable name. 
printer << ' '; @@ -2217,8 +2212,7 @@ static void print(spirv::SubgroupBlockReadINTELOp blockReadOp, OpAsmPrinter &printer) { SmallVector elidedAttrs; - printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " " - << blockReadOp.ptr(); + printer << " " << blockReadOp.ptr(); printer << " : " << blockReadOp.getType(); } @@ -2261,8 +2255,7 @@ static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp, OpAsmPrinter &printer) { SmallVector elidedAttrs; - printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " " - << blockWriteOp.ptr() << ", " << blockWriteOp.value(); + printer << " " << blockWriteOp.ptr() << ", " << blockWriteOp.value(); printer << " : " << blockWriteOp.value().getType(); } @@ -2331,8 +2324,7 @@ SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( loadOp.ptr().getType().cast().getStorageClass()); - printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" " - << loadOp.ptr(); + printer << " \"" << sc << "\" " << loadOp.ptr(); printMemoryAccessAttribute(loadOp, printer, elidedAttrs); @@ -2372,7 +2364,6 @@ static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) { auto *op = loopOp.getOperation(); - printer << spirv::LoopOp::getOperationName(); auto control = loopOp.loop_control(); if (control != spirv::LoopControl::None) printer << " control(" << spirv::stringifyLoopControl(control) << ")"; @@ -2597,8 +2588,6 @@ } static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) { - printer << spirv::ModuleOp::getOperationName(); - if (Optional name = moduleOp.getName()) { printer << ' '; printer.printSymbolName(*name); @@ -2770,8 +2759,6 @@ static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) { auto *op = selectionOp.getOperation(); - - printer << spirv::SelectionOp::getOperationName(); auto control = selectionOp.selection_control(); if (control != spirv::SelectionControl::None) printer << " control(" << spirv::stringifySelectionControl(control) << ")"; @@ -2903,7 +2890,7 @@ } 
static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) { - printer << spirv::SpecConstantOp::getOperationName() << ' '; + printer << ' '; printer.printSymbolName(constOp.sym_name()); if (auto specID = constOp->getAttrOfType(kSpecIdAttrName)) printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; @@ -2956,8 +2943,7 @@ SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( storeOp.ptr().getType().cast().getStorageClass()); - printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" " - << storeOp.ptr() << ", " << storeOp.value(); + printer << " \"" << sc << "\" " << storeOp.ptr() << ", " << storeOp.value(); printMemoryAccessAttribute(storeOp, printer, elidedAttrs); @@ -3043,8 +3029,6 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) { SmallVector elidedAttrs{ spirv::attributeName()}; - printer << spirv::VariableOp::getOperationName(); - // Print optional initializer if (varOp.getNumOperands() != 0) printer << " init(" << varOp.initializer() << ")"; @@ -3154,8 +3138,8 @@ } static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) { - printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " " - << M.pointer() << ", " << M.stride() << ", " << M.columnmajor(); + printer << " " << M.pointer() << ", " << M.stride() << ", " + << M.columnmajor(); // Print optional memory access attribute. if (auto memAccess = M.memory_access()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; @@ -3209,8 +3193,7 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix, OpAsmPrinter &printer) { - printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " " - << coopMatrix.pointer() << ", " << coopMatrix.object() << ", " + printer << " " << coopMatrix.pointer() << ", " << coopMatrix.object() << ", " << coopMatrix.stride() << ", " << coopMatrix.columnmajor(); // Print optional memory access attribute. 
if (auto memAccess = coopMatrix.memory_access()) @@ -3288,7 +3271,7 @@ static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) { auto *op = copyMemory.getOperation(); - printer << spirv::CopyMemoryOp::getOperationName() << ' '; + printer << ' '; StringRef targetStorageClass = stringifyStorageClass(copyMemory.target() @@ -3494,7 +3477,7 @@ } static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) { - printer << spirv::SpecConstantCompositeOp::getOperationName() << " "; + printer << " "; printer.printSymbolName(op.sym_name()); printer << " ("; auto constituents = op.constituents().getValue(); @@ -3570,7 +3553,7 @@ } static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) { - printer << op.getOperationName() << " wraps "; + printer << " wraps "; printer.printGenericOp(&op.body().front().front()); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -280,7 +280,7 @@ static void print(OpAsmPrinter &p, AssumingOp op) { bool yieldsResults = !op.results().empty(); - p << AssumingOp::getOperationName() << " " << op.witness(); + p << " " << op.witness(); if (yieldsResults) { p << " -> (" << op.getResultTypes() << ")"; } @@ -760,7 +760,7 @@ //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ConstShapeOp &op) { - p << "shape.const_shape "; + p << " "; p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); p << "["; interleaveComma(op.shape().getValues(), p, @@ -1096,7 +1096,7 @@ } void print(OpAsmPrinter &p, FunctionLibraryOp op) { - p << op.getOperationName() << ' '; + p << ' '; p.printSymbolName(op.getName()); p.printOptionalAttrDictWithKeyword( op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); @@ -1691,7 +1691,7 @@ } static void print(OpAsmPrinter &p, ReduceOp op) { - p << op.getOperationName() << '(' << 
op.shape() << ", " << op.initVals() + p << '(' << op.shape() << ", " << op.initVals() << ") : " << op.shape().getType(); p.printOptionalArrowTypeList(op.getResultTypes()); p.printRegion(op.region()); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -105,9 +105,7 @@ assert(op->getNumOperands() == 1 && "unary op should have one operand"); assert(op->getNumResults() == 1 && "unary op should have one result"); - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' - << op->getOperand(0); + p << ' ' << op->getOperand(0); p.printOptionalAttrDict(op->getAttrs()); p << " : " << op->getOperand(0).getType(); } @@ -127,9 +125,7 @@ return; } - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' - << op->getOperand(0) << ", " << op->getOperand(1); + p << ' ' << op->getOperand(0) << ", " << op->getOperand(1); p.printOptionalAttrDict(op->getAttrs()); // Now we can output only one type for all operands and the result. @@ -152,9 +148,7 @@ return; } - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' - << op->getOperand(0) << ", " << op->getOperand(1) << ", " + p << ' ' << op->getOperand(0) << ", " << op->getOperand(1) << ", " << op->getOperand(2); p.printOptionalAttrDict(op->getAttrs()); @@ -165,10 +159,8 @@ /// A custom cast operation printer that omits the "std." prefix from the /// operation names. 
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' - << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to " - << op->getResult(0).getType(); + p << ' ' << op->getOperand(0) << " : " << op->getOperand(0).getType() + << " to " << op->getResult(0).getType(); } void StandardOpsDialect::initialize() { @@ -465,7 +457,7 @@ } static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) { - p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices() + p << ' ' << op.memref() << "[" << op.indices() << "] : " << op.memref().getType(); p.printRegion(op.body()); p.printOptionalAttrDict(op->getAttrs()); @@ -1133,7 +1125,7 @@ //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ConstantOp &op) { - p << "constant "; + p << " "; p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); if (op->getAttrs().size() > 1) @@ -1641,7 +1633,7 @@ } static void print(OpAsmPrinter &p, SelectOp op) { - p << "select " << op.getOperands(); + p << " " << op.getOperands(); p.printOptionalAttrDict(op->getAttrs()); p << " : "; if (ShapedType condType = op.getCondition().getType().dyn_cast()) diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -337,7 +337,7 @@ } static void print(OpAsmPrinter &p, ReductionOp op) { - p << op.getOperationName() << " \"" << op.kind() << "\", " << op.vector(); + p << " \"" << op.kind() << "\", " << op.vector(); if (!op.acc().empty()) p << ", " << op.acc(); p << " : " << op.vector().getType() << " into " << op.dest().getType(); @@ -453,7 +453,7 @@ attrs.push_back(attr); auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); - p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << 
", "; + p << " " << dictAttr << " " << op.lhs() << ", "; p << op.rhs() << ", " << op.acc(); if (op.masks().size() == 2) p << ", " << op.masks(); @@ -846,7 +846,7 @@ } static void print(OpAsmPrinter &p, vector::ExtractOp op) { - p << op.getOperationName() << " " << op.vector() << op.position(); + p << " " << op.vector() << op.position(); p.printOptionalAttrDict(op->getAttrs(), {"position"}); p << " : " << op.vector().getType(); } @@ -1389,8 +1389,7 @@ } static void print(OpAsmPrinter &p, ShuffleOp op) { - p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " " - << op.mask(); + p << " " << op.v1() << ", " << op.v2() << " " << op.mask(); p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()}); p << " : " << op.v1().getType() << ", " << op.v2().getType(); } @@ -1757,7 +1756,7 @@ } static void print(OpAsmPrinter &p, OuterProductOp op) { - p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs(); + p << " " << op.lhs() << ", " << op.rhs(); if (!op.acc().empty()) { p << ", " << op.acc(); p.printOptionalAttrDict(op->getAttrs()); @@ -2317,8 +2316,8 @@ "as permutation_map results: ") << AffineMapAttr::get(permutationMap); for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) - if (permutationMap.getResult(i).isa() - && !inBounds.getValue()[i].cast().getValue()) + if (permutationMap.getResult(i).isa() && + !inBounds.getValue()[i].cast().getValue()) return op->emitOpError("requires broadcast dimensions to be in-bounds"); } @@ -2404,8 +2403,7 @@ } static void print(OpAsmPrinter &p, TransferReadOp op) { - p << op.getOperationName() << " " << op.source() << "[" << op.indices() - << "], " << op.padding(); + p << " " << op.source() << "[" << op.indices() << "], " << op.padding(); if (op.mask()) p << ", " << op.mask(); printTransferAttrs(p, cast(op.getOperation())); @@ -2573,8 +2571,8 @@ // inBounds. 
auto dimExpr = permutationMap.getResult(i).dyn_cast(); assert(dimExpr && "Broadcast dims must be in-bounds"); - auto inBounds = isInBounds( - op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition()); + auto inBounds = + isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition()); newInBounds.push_back(inBounds); // We commit the pattern if it is "more inbounds". changed |= inBounds; @@ -2745,8 +2743,7 @@ } static void print(OpAsmPrinter &p, TransferWriteOp op) { - p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "[" - << op.indices() << "]"; + p << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]"; if (op.mask()) p << ", " << op.mask(); printTransferAttrs(p, cast(op.getOperation())); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -382,7 +382,7 @@ private: /// Print the given operation in the generic form. - void printGenericOp(Operation *op) override { + void printGenericOp(Operation *op, bool printOpName = true) override { // Consider nested operations for aliases. if (op->getNumRegions() != 0) { for (Region &region : op->getRegions()) @@ -2318,7 +2318,7 @@ /// Print the bare location, not including indentation/location/etc. void printOperation(Operation *op); /// Print the given operation in the generic form. - void printGenericOp(Operation *op) override; + void printGenericOp(Operation *op, bool printOpName) override; /// Print the name of the given block. void printBlockName(Block *block); @@ -2509,6 +2509,10 @@ // Otherwise try to dispatch to the dialect, if available. if (Dialect *dialect = op->getDialect()) { if (auto opPrinter = dialect->getOperationPrinter(op)) { + // Print the op name first. + StringRef name = op->getName().getStringRef(); + printEscapedString(name, os); + // Print the op now. opPrinter(op, *this); return; } @@ -2516,13 +2520,16 @@ } // Otherwise print with the generic assembly form. 
- printGenericOp(op); + printGenericOp(op, /*printOpName=*/true); } -void OperationPrinter::printGenericOp(Operation *op) { - os << '"'; - printEscapedString(op->getName().getStringRef(), os); - os << "\"("; +void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { + if (printOpName) { + os << '"'; + printEscapedString(op->getName().getStringRef(), os); + os << '"'; + } + os << "("; interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); os << ')'; diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -359,7 +359,7 @@ auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) .getValue(); - p << op->getName() << ' '; + p << ' '; StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); if (auto visibility = op->getAttrOfType(visibilityAttrName)) diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FoldInterfaces.h" +#include "llvm/ADT/StringExtras.h" #include using namespace mlir; @@ -641,6 +642,13 @@ // The fallback for the printer is to print in the generic assembly form. void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } +// Print the operation name, dropping a leading "std." prefix if present. +void OpState::printOpName(Operation *op, OpAsmPrinter &p) { + StringRef name = op->getName().getStringRef(); + if (name.startswith("std.")) + name = name.drop_front(4); + llvm::printEscapedString(name, p.getStream()); +} /// Emit an error about fatal conditions with this operation, reporting up to /// any diagnostic handlers that may be listening. 
@@ -1167,6 +1175,24 @@ OperationState &result) { SmallVector ops; Type type; + // If the operand list is in-between parentheses, then we have a generic form. + // (see the fallback in `printOneResultOp`). + llvm::SMLoc loc = parser.getCurrentLocation(); + if (!parser.parseOptionalLParen()) { + if (parser.parseOperandList(ops) || parser.parseRParen() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseType(type)) + return failure(); + auto fnType = type.dyn_cast(); + if (!fnType) { + parser.emitError(loc, "expected function type"); + return failure(); + } + if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) + return failure(); + result.addTypes(fnType.getResults()); + return success(); + } return failure(parser.parseOperandList(ops) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || @@ -1182,11 +1208,11 @@ auto resultType = op->getResult(0).getType(); if (llvm::any_of(op->getOperandTypes(), [&](Type type) { return type != resultType; })) { - p.printGenericOp(op); + p.printGenericOp(op, /*printOpName=*/false); return; } - p << op->getName() << ' '; + p << ' '; p.printOperands(op->getOperands()); p.printOptionalAttrDict(op->getAttrs()); // Now we can output only one type for all operands and the result. 
@@ -1257,7 +1283,7 @@ } void impl::printCastOp(Operation *op, OpAsmPrinter &p) { - p << op->getName() << ' ' << op->getOperand(0); + p << ' ' << op->getOperand(0); p.printOptionalAttrDict(op->getAttrs()); p << " : " << op->getOperand(0).getType() << " to " << op->getResult(0).getType(); diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir --- a/mlir/test/Dialect/Standard/expand-ops.mlir +++ b/mlir/test/Dialect/Standard/expand-ops.mlir @@ -6,7 +6,7 @@ %x = atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32 return %x : f32 } -// CHECK: %0 = std.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { // CHECK: ^bb0([[CUR_VAL:%.*]]: f32): // CHECK: [[CMP:%.*]] = cmpf ogt, [[CUR_VAL]], [[f]] : f32 // CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -318,7 +318,7 @@ StringRef opName = op->getName().getStringRef(); if (opName == "test.dialect_custom_printer") { return [](Operation *op, OpAsmPrinter &printer) { - printer.getStream() << op->getName().getStringRef() << " custom_format"; + printer.getStream() << " custom_format"; }; } return {}; @@ -656,7 +656,6 @@ } static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { - p << ParseIntegerLiteralOp::getOperationName(); if (unsigned numResults = op->getNumResults()) p << " : " << numResults; } @@ -671,7 +670,7 @@ } static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { - p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); + p << " " << op.keyword(); } //===----------------------------------------------------------------------===// @@ -709,7 +708,7 @@ } static void print(OpAsmPrinter &p, WrappingRegionOp op) { - p << op.getOperationName() << " wraps "; + p << " 
wraps "; p.printGenericOp(&op.region().front().front()); } @@ -961,8 +960,6 @@ } static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { - p << "test.string_attr_pretty_name"; - // Note that we only need to print the "name" attribute if the asmprinter // result name disagrees with it. This can happen in strange cases, e.g. // when there are conflicts. @@ -1004,7 +1001,7 @@ //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, RegionIfOp op) { - p << RegionIfOp::getOperationName() << " "; + p << " "; p.printOperands(op.getOperands()); p << ": " << op.getOperandTypes(); p.printArrowTypeList(op.getResultTypes()); @@ -1083,7 +1080,6 @@ } static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { - printer << op.getOperationName(); printer.printRegion( op.getRegion(), /*printEntryBlockArgs=*/false, // This op has a single block without terminators. But explicitly mark diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2045,16 +2045,6 @@ opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &p"); auto &body = method->body(); - // Emit the operation name, trimming the prefix if this is the standard - // dialect. - body << " p << \""; - std::string opName = op.getOperationName(); - if (op.getDialectName() == "std") - body << StringRef(opName).drop_front(4); - else - body << opName; - body << "\";\n"; - // Flags for if we should emit a space, and if the last element was // punctuation. bool shouldEmitSpace = true, lastWasPunctuation = false;