diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2534,19 +2534,15 @@ class fir_ArithmeticOp traits = []> : fir_Op, - Results<(outs AnyType)> { - let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);"; - - let printer = "return printBinaryOp(this->getOperation(), p);"; + Results<(outs AnyType:$result)> { + let assemblyFormat = "operands attr-dict `:` type($result)"; } class fir_UnaryArithmeticOp traits = []> : fir_Op, - Results<(outs AnyType)> { - let parser = "return impl::parseOneResultSameOperandTypeOp(parser, result);"; - - let printer = "return printUnaryOp(this->getOperation(), p);"; + Results<(outs AnyType:$result)> { + let assemblyFormat = "operands attr-dict `:` type($result)"; } def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -3211,26 +3211,6 @@ return mlir::success(); } -/// Generic pretty-printer of a binary operation -static void printBinaryOp(Operation *op, OpAsmPrinter &p) { - assert(op->getNumOperands() == 2 && "binary op must have two operands"); - assert(op->getNumResults() == 1 && "binary op must have one result"); - - p << ' ' << op->getOperand(0) << ", " << op->getOperand(1); - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op->getResult(0).getType(); -} - -/// Generic pretty-printer of an unary operation -static void printUnaryOp(Operation *op, OpAsmPrinter &p) { - assert(op->getNumOperands() == 1 && "unary op must have one operand"); - assert(op->getNumResults() == 1 && "unary op must have one result"); - - p << ' ' << op->getOperand(0); - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op->getResult(0).getType(); -} - bool fir::isReferenceLike(mlir::Type type) { return type.isa() || type.isa() || type.isa(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -419,8 +419,7 @@ let arguments = (ins type:$arg); let results = (outs resultType:$res); let builders = [LLVM_OneResultOpBuilder]; - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)"; } def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast", LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> { diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -383,7 +383,6 @@ }]; let hasFolder = 1; - let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -34,6 +34,7 @@ let results = (outs SPV_ScalarOrVectorOrCoopMatrixOf:$result ); + let assemblyFormat = "operands attr-dict `:` type($result)"; } class SPV_ArithmeticUnaryOp:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; // No additional verification needed in addition to the ODS-generated ones. let hasVerifier = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td @@ -21,7 +21,9 @@ // All the operands type used in bit instructions are SPV_Integer. SPV_BinaryOp; + [NoSideEffect, SameOperandsAndResultType])> { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} class SPV_BitFieldExtractOp traits = []> : SPV_Op:$result ); - - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; } // ----- @@ -85,9 +85,9 @@ SPV_ScalarOrVectorOrPtr:$result ); - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; - + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td @@ -72,10 +72,6 @@ SPV_ScalarOrVectorOf:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; - let hasVerifier = 0; } @@ -83,7 +79,10 @@ // return type matches. class SPV_GLSLBinaryArithmeticOp traits = []> : - SPV_GLSLBinaryOp; + SPV_GLSLBinaryOp { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} // Base class for GLSL ternary ops. class SPV_GLSLTernaryArithmeticOp:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; + let parser = [{ return parseOneResultSameOperandTypeOp(parser, result); }]; + let printer = [{ return printOneResultOp(getOperation(), p); }]; let hasVerifier = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td @@ -71,10 +71,6 @@ SPV_ScalarOrVectorOf:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; - let hasVerifier = 0; } @@ -82,7 +78,10 @@ // return type matches. class SPV_OCLBinaryArithmeticOp traits = []> : - SPV_OCLBinaryOp; + SPV_OCLBinaryOp { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} // ----- diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -14,6 +14,7 @@ #define SHAPE_OPS include "mlir/Dialect/Shape/IR/ShapeBase.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -331,7 +332,9 @@ }]; } -def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { +def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [ + DeclareOpInterfaceMethods, NoSideEffect + ]> { let summary = "Creates a dimension tensor from a shape"; let description = [{ Converts a shape to a 1D integral tensor of extents. The number of elements @@ -624,7 +627,9 @@ }]; } -def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> { +def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [ + DeclareOpInterfaceMethods, NoSideEffect + ]> { let summary = "Casts between index types of the shape and standard dialect"; let description = [{ Converts a `shape.size` to a standard index. This operation and its diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1897,26 +1897,9 @@ }; //===----------------------------------------------------------------------===// -// Common Operation Folders/Parsers/Printers +// CastOpInterface utilities //===----------------------------------------------------------------------===// -// These functions are out-of-line implementations of the methods in UnaryOp and -// BinaryOp, which avoids them being template instantiated/duplicated. -namespace impl { -ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser, - OperationState &result); - -void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, - Value rhs); -ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, - OperationState &result); - -// Prints the given binary `op` in custom assembly form if both the two operands -// and the result have the same time. Otherwise, prints the generic assembly -// form. -void printOneResultOp(Operation *op, OpAsmPrinter &p); -} // namespace impl - // These functions are out-of-line implementations of the methods in // CastOpInterface, which avoids them being template instantiated/duplicated. namespace impl { @@ -1927,20 +1910,6 @@ /// Attempt to verify the given cast operation. LogicalResult verifyCastInterfaceOp( Operation *op, function_ref areCastCompatible); - -// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the -// need for them, but some older ODS code in `std` still depends on them). -void buildCastOp(OpBuilder &builder, OperationState &result, Value source, - Type destType); -ParseResult parseCastOp(OpAsmParser &parser, OperationState &result); -void printCastOp(Operation *op, OpAsmPrinter &p); -// TODO: These methods are deprecated in favor of CastOpInterface. Remove them -// when all uses have been updated. Also, consider adding functionality to -// CastOpInterface to be able to perform the ChainedTensorCast canonicalization -// generically. -Value foldCastOp(Operation *op); -LogicalResult verifyCastOp(Operation *op, - function_ref areCastCompatible); } // namespace impl } // namespace mlir diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -65,10 +65,6 @@ return NoneType::get(type.getContext()); } -LogicalResult memref::CastOp::verify() { - return impl::verifyCastOp(*this, areCastCompatible); -} - //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -64,6 +64,54 @@ // Common utility functions //===----------------------------------------------------------------------===// +static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, + OperationState &result) { + SmallVector ops; + Type type; + // If the operand list is in-between parentheses, then we have a generic form. + // (see the fallback in `printOneResultOp`). + SMLoc loc = parser.getCurrentLocation(); + if (!parser.parseOptionalLParen()) { + if (parser.parseOperandList(ops) || parser.parseRParen() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseType(type)) + return failure(); + auto fnType = type.dyn_cast(); + if (!fnType) { + parser.emitError(loc, "expected function type"); + return failure(); + } + if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) + return failure(); + result.addTypes(fnType.getResults()); + return success(); + } + return failure(parser.parseOperandList(ops) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperands(ops, type, result.operands) || + parser.addTypeToList(type, result.types)); +} + +static void printOneResultOp(Operation *op, OpAsmPrinter &p) { + assert(op->getNumResults() == 1 && "op should have one result"); + + // If not all the operand and result types are the same, just use the + // generic assembly form to avoid omitting information in printing. + auto resultType = op->getResult(0).getType(); + if (llvm::any_of(op->getOperandTypes(), + [&](Type type) { return type != resultType; })) { + p.printGenericOp(op, /*printOpName=*/false); + return; + } + + p << ' '; + p.printOperands(op->getOperands()); + p.printOptionalAttrDict(op->getAttrs()); + // Now we can output only one type for all operands and the result. + p << " : " << resultType; +} + /// Returns true if the given op is a function-like op or nested in a /// function-like op without a module-like op in the middle. static bool isNestedInFunctionOpInterface(Operation *op) { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1692,7 +1692,7 @@ // `IntegerAttr`s which makes constant folding simple. if (Attribute arg = operands[0]) return arg; - return impl::foldCastOp(*this); + return OpFoldResult(); } void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -1700,6 +1700,12 @@ patterns.add(context); } +bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + return inputs[0].isa() && outputs[0].isa(); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// @@ -1750,7 +1756,7 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef operands) { if (!operands[0]) - return impl::foldCastOp(*this); + return OpFoldResult(); Builder builder(getContext()); auto shape = llvm::to_vector<6>( operands[0].cast().getValues()); @@ -1759,6 +1765,21 @@ return DenseIntElementsAttr::get(type, shape); } +bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + if (auto inputTensor = inputs[0].dyn_cast()) { + if (!inputTensor.getElementType().isa() || + inputTensor.getRank() != 1 || !inputTensor.isDynamicDim(0)) + return false; + } else if (!inputs[0].isa()) { + return false; + } + + TensorType outputTensor = outputs[0].dyn_cast(); + return outputTensor && outputTensor.getElementType().isa(); +} + //===----------------------------------------------------------------------===// // ReduceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1125,69 +1125,7 @@ } //===----------------------------------------------------------------------===// -// BinaryOp implementation -//===----------------------------------------------------------------------===// - -// These functions are out-of-line implementations of the methods in BinaryOp, -// which avoids them being template instantiated/duplicated. - -void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, - Value rhs) { - assert(lhs.getType() == rhs.getType()); - result.addOperands({lhs, rhs}); - result.types.push_back(lhs.getType()); -} - -ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, - OperationState &result) { - SmallVector ops; - Type type; - // If the operand list is in-between parentheses, then we have a generic form. - // (see the fallback in `printOneResultOp`). - SMLoc loc = parser.getCurrentLocation(); - if (!parser.parseOptionalLParen()) { - if (parser.parseOperandList(ops) || parser.parseRParen() || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(type)) - return failure(); - auto fnType = type.dyn_cast(); - if (!fnType) { - parser.emitError(loc, "expected function type"); - return failure(); - } - if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) - return failure(); - result.addTypes(fnType.getResults()); - return success(); - } - return failure(parser.parseOperandList(ops) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperands(ops, type, result.operands) || - parser.addTypeToList(type, result.types)); -} - -void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { - assert(op->getNumResults() == 1 && "op should have one result"); - - // If not all the operand and result types are the same, just use the - // generic assembly form to avoid omitting information in printing. - auto resultType = op->getResult(0).getType(); - if (llvm::any_of(op->getOperandTypes(), - [&](Type type) { return type != resultType; })) { - p.printGenericOp(op, /*printOpName=*/false); - return; - } - - p << ' '; - p.printOperands(op->getOperands()); - p.printOptionalAttrDict(op->getAttrs()); - // Now we can output only one type for all operands and the result. - p << " : " << resultType; -} - -//===----------------------------------------------------------------------===// -// CastOp implementation +// CastOpInterface //===----------------------------------------------------------------------===// /// Attempt to fold the given cast operation. @@ -1232,50 +1170,6 @@ return success(); } -void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, - Type destType) { - result.addOperands(source); - result.addTypes(destType); -} - -ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType srcInfo; - Type srcType, dstType; - return failure(parser.parseOperand(srcInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.resolveOperand(srcInfo, srcType, result.operands) || - parser.parseKeywordType("to", dstType) || - parser.addTypeToList(dstType, result.types)); -} - -void impl::printCastOp(Operation *op, OpAsmPrinter &p) { - p << ' ' << op->getOperand(0); - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op->getOperand(0).getType() << " to " - << op->getResult(0).getType(); -} - -Value impl::foldCastOp(Operation *op) { - // Identity cast - if (op->getOperand(0).getType() == op->getResult(0).getType()) - return op->getOperand(0); - return nullptr; -} - -LogicalResult -impl::verifyCastOp(Operation *op, - function_ref areCastCompatible) { - auto opType = op->getOperand(0).getType(); - auto resType = op->getResult(0).getType(); - if (!areCastCompatible(opType, resType)) - return op->emitError("operand type ") - << opType << " and result type " << resType - << " are cast incompatible"; - - return success(); -} - //===----------------------------------------------------------------------===// // Misc. utils //===----------------------------------------------------------------------===//