diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -52,7 +52,7 @@ Type type); inline int64_t value() { - return arith::ConstantOp::value().cast<IntegerAttr>().getInt(); + return arith::ConstantOp::getValue().cast<IntegerAttr>().getInt(); } static bool classof(Operation *op); @@ -68,7 +68,7 @@ const APFloat &value, FloatType type); inline APFloat value() { - return arith::ConstantOp::value().cast<FloatAttr>().getValue(); + return arith::ConstantOp::getValue().cast<FloatAttr>().getValue(); } static bool classof(Operation *op); @@ -83,7 +83,7 @@ static void build(OpBuilder &builder, OperationState &result, int64_t value); inline int64_t value() { - return arith::ConstantOp::value().cast<IntegerAttr>().getInt(); + return arith::ConstantOp::getValue().cast<IntegerAttr>().getInt(); } static bool classof(Operation *op); diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td @@ -22,6 +22,7 @@ }]; let hasConstantMaterializer = 1; + let emitAccessorPrefix = kEmitAccessorPrefix_Both; } // The predicate indicates the type of the comparison to perform: diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -956,13 +956,7 @@ ]; let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } static arith::CmpIPredicate getPredicateByName(StringRef name); - - arith::CmpIPredicate getPredicate() { - return (arith::CmpIPredicate) (*this)->getAttrOfType<IntegerAttr>( - getPredicateAttrName()).getInt(); - } }]; let hasFolder = 1; @@ -1012,13 +1006,7 @@ ]; let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } static arith::CmpFPredicate getPredicateByName(StringRef name); - - arith::CmpFPredicate getPredicate() { - return (arith::CmpFPredicate) (*this)->getAttrOfType<IntegerAttr>( - getPredicateAttrName()).getInt(); - } }]; let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -47,6 +47,8 @@ /// Name of the target triple attribute. 
static StringRef getTargetTripleAttrName() { return "llvm.target_triple"; } }]; + + let emitAccessorPrefix = kEmitAccessorPrefix_Both; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -443,7 +443,7 @@ DeclareOpInterfaceMethods, Terminator]> { let arguments = (ins OptionalAttr:$callee, - Variadic:$operands, + Variadic:$callee_operands, Variadic:$normalDestOperands, Variadic:$unwindDestOperands); let results = (outs Variadic); @@ -1006,7 +1006,7 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global", [IsolatedFromAbove, SingleBlockImplicitTerminator<"ReturnOp">, Symbol]> { let arguments = (ins - TypeAttr:$type, + TypeAttr:$global_type, UnitAttr:$constant, StrAttr:$sym_name, Linkage:$linkage, @@ -1128,7 +1128,7 @@ let extraClassDeclaration = [{ /// Return the LLVM type of the global. Type getType() { - return type(); + return getGlobalType(); } /// Return the initializer attribute if it exists, or a null attribute. Attribute getValueOrNull() { @@ -1810,7 +1810,7 @@ considered undefined behavior at this time. }]; let arguments = ( - ins Variadic:$operands, + ins Variadic, StrAttr:$asm_string, StrAttr:$constraints, UnitAttr:$has_side_effects, diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -39,6 +39,7 @@ let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; + let emitAccessorPrefix = kEmitAccessorPrefix_Both; } def Shape_ShapeType : DialectType { let summary = "Returns the value to parent op"; - let arguments = (ins Variadic:$operands); + let arguments = (ins Variadic:$args); let builders = [OpBuilder<(ins), [{ build($_builder, $_state, llvm::None); }]> ]; let verifier = [{ return ::verify(*this); }]; - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let assemblyFormat = "attr-dict ($args^ `:` type($args))?"; } // TODO: Add Ops: if_static, if_ranked @@ -885,13 +885,13 @@ the number and types of parent `shape.assuming` results. }]; - let arguments = (ins Variadic:$operands); + let arguments = (ins Variadic:$args); let builders = [ OpBuilder<(ins), [{ /* nothing to do */ }]>, ]; - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let assemblyFormat = "attr-dict ($args^ `:` type($args))?"; } def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -27,6 +27,7 @@ let cppNamespace = "::mlir"; let dependentDialects = ["arith::ArithmeticDialect"]; let hasConstantMaterializer = 1; + let emitAccessorPrefix = kEmitAccessorPrefix_Both; } // Base class for Standard dialect ops. @@ -295,15 +296,16 @@ let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result); - let regions = (region AnyRegion:$body); + let regions = (region AnyRegion); let skipDefaultBuilders = 1; let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>]; let extraClassDeclaration = [{ + Region &body() { return getRegion(); } // The value stored in memref[ivs]. 
Value getCurrentValue() { - return body().getArgument(0); + return getRegion().getArgument(0); } MemRefType getMemRefType() { return memref().getType().cast(); @@ -363,7 +365,6 @@ let verifier = ?; let extraClassDeclaration = [{ - Block *getDest(); void setDest(Block *block); /// Erase the operand at 'index' from the operand list. @@ -397,7 +398,7 @@ ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic); let results = (outs Variadic); let builders = [ @@ -423,8 +424,6 @@ }]>]; let extraClassDeclaration = [{ - StringRef getCallee() { return callee(); } - StringAttr getCalleeAttr() { return calleeAttr().getAttr(); } FunctionType getCalleeType(); /// Get the argument operands to the called function. @@ -442,7 +441,7 @@ }]; let assemblyFormat = [{ - $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + $callee `(` operands `)` attr-dict `:` functional-type(operands, results) }]; let verifier = ?; } @@ -454,7 +453,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [ CallOpInterface, TypesMatchWith<"callee input types match argument types", - "callee", "operands", + "callee", "callee_operands", "$_self.cast().getInputs()">, TypesMatchWith<"callee result types match result types", "callee", "results", @@ -478,7 +477,8 @@ ``` }]; - let arguments = (ins FunctionType:$callee, Variadic:$operands); + let arguments = (ins FunctionType:$callee, + Variadic:$callee_operands); let results = (outs Variadic:$results); let builders = [ @@ -489,8 +489,6 @@ }]>]; let extraClassDeclaration = [{ - Value getCallee() { return getOperand(0); } - /// Get the argument operands to the called function. operand_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; @@ -506,7 +504,8 @@ let verifier = ?; let hasCanonicalizeMethod = 1; - let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)"; + let assemblyFormat = + "$callee `(` $callee_operands `)` attr-dict `:` type($callee)"; } //===----------------------------------------------------------------------===// @@ -570,19 +569,6 @@ // These are the indices into the dests list. enum { trueIndex = 0, falseIndex = 1 }; - // The condition operand is the first operand in the list. - Value getCondition() { return getOperand(0); } - - /// Return the destination if the condition is true. - Block *getTrueDest() { - return getSuccessor(trueIndex); - } - - /// Return the destination if the condition is false. - Block *getFalseDest() { - return getSuccessor(falseIndex); - } - // Accessors for operands to the 'true' destination. Value getTrueOperand(unsigned idx) { assert(idx < getNumTrueOperands()); @@ -594,8 +580,6 @@ setOperand(getTrueDestOperandIndex() + idx, value); } - operand_range getTrueOperands() { return trueDestOperands(); } - unsigned getNumTrueOperands() { return getTrueOperands().size(); } /// Erase the operand at 'index' from the true operand list. @@ -613,7 +597,8 @@ setOperand(getFalseDestOperandIndex() + idx, value); } - operand_range getFalseOperands() { return falseDestOperands(); } + operand_range getTrueOperands() { return getTrueDestOperands(); } + operand_range getFalseOperands() { return getFalseDestOperands(); } unsigned getNumFalseOperands() { return getFalseOperands().size(); } @@ -694,8 +679,6 @@ ]; let extraClassDeclaration = [{ - Attribute getValue() { return (*this)->getAttr("value"); } - /// Returns true if a constant operation can be built with the given value /// and result type. 
static bool isBuildableWith(Attribute value, Type type); @@ -971,12 +954,6 @@ $_state.addTypes(trueValue.getType()); }]>]; - let extraClassDeclaration = [{ - Value getCondition() { return condition(); } - Value getTrueValue() { return true_value(); } - Value getFalseValue() { return false_value(); } - }]; - let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp --- a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp +++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp @@ -134,16 +134,16 @@ typeConverter->convertType(getElementTypeOrSelf(op.getResult())) .cast(); auto sourceElementType = - getElementTypeOrSelf(adaptor.in()).cast(); + getElementTypeOrSelf(adaptor.getIn()).cast(); unsigned targetBits = targetElementType.getWidth(); unsigned sourceBits = sourceElementType.getWidth(); if (targetBits == sourceBits) - rewriter.replaceOp(op, adaptor.in()); + rewriter.replaceOp(op, adaptor.getIn()); else if (targetBits < sourceBits) - rewriter.replaceOpWithNewOp(op, targetType, adaptor.in()); + rewriter.replaceOpWithNewOp(op, targetType, adaptor.getIn()); else - rewriter.replaceOpWithNewOp(op, targetType, adaptor.in()); + rewriter.replaceOpWithNewOp(op, targetType, adaptor.getIn()); return success(); } @@ -161,7 +161,7 @@ LogicalResult CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto operandType = adaptor.lhs().getType(); + auto operandType = adaptor.getLhs().getType(); auto resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. @@ -169,7 +169,7 @@ rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), - adaptor.lhs(), adaptor.rhs()); + adaptor.getLhs(), adaptor.getRhs()); return success(); } @@ -184,7 +184,7 @@ return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), - adaptor.lhs(), adaptor.rhs()); + adaptor.getLhs(), adaptor.getRhs()); }, rewriter); @@ -198,7 +198,7 @@ LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto operandType = adaptor.lhs().getType(); + auto operandType = adaptor.getLhs().getType(); auto resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. @@ -206,7 +206,7 @@ rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), - adaptor.lhs(), adaptor.rhs()); + adaptor.getLhs(), adaptor.getRhs()); return success(); } @@ -221,7 +221,7 @@ return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), - adaptor.lhs(), adaptor.rhs()); + adaptor.getLhs(), adaptor.getRhs()); }, rewriter); } diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -274,7 +274,7 @@ if (!dstType) return failure(); - auto dstElementsAttr = constOp.value().dyn_cast(); + auto dstElementsAttr = constOp.getValue().dyn_cast(); ShapedType dstAttrType = dstElementsAttr.getType(); if (!dstElementsAttr) return failure(); @@ -358,7 +358,7 @@ // Floating-point types. 
if (srcType.isa()) { - auto srcAttr = constOp.value().cast(); + auto srcAttr = constOp.getValue().cast(); auto dstAttr = srcAttr; // Floating-point types not supported in the target environment are all @@ -377,7 +377,7 @@ if (srcType.isInteger(1)) { // arith.constant can use 0/1 instead of true/false for i1 values. We need // to handle that here. - auto dstAttr = convertBoolAttr(constOp.value(), rewriter); + auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); @@ -386,7 +386,7 @@ // IndexType or IntegerType. Index values are converted to 32-bit integer // values when converting to SPIR-V. - auto srcAttr = constOp.value().cast(); + auto srcAttr = constOp.getValue().cast(); auto dstAttr = convertIntegerAttr(srcAttr, dstType.cast(), rewriter); if (!dstAttr) @@ -604,7 +604,7 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite( arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Type operandType = op.lhs().getType(); + Type operandType = op.getLhs().getType(); if (!isBoolScalarOrVector(operandType)) return failure(); @@ -631,7 +631,7 @@ LogicalResult CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Type operandType = op.lhs().getType(); + Type operandType = op.getLhs().getType(); if (isBoolScalarOrVector(operandType)) return failure(); @@ -708,14 +708,14 @@ arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (op.getPredicate() == arith::CmpFPredicate::ORD) { - rewriter.replaceOpWithNewOp(op, adaptor.lhs(), - adaptor.rhs()); + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); return success(); } if (op.getPredicate() == arith::CmpFPredicate::UNO) { - rewriter.replaceOpWithNewOp(op, adaptor.lhs(), - adaptor.rhs()); + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); return success(); } @@ -735,8 +735,8 @@ Location loc = op.getLoc(); - Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); + Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); + Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); if (op.getPredicate() == arith::CmpFPredicate::ORD) diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -743,7 +743,7 @@ op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); rewriter.create(op->getLoc(), apiFuncName, TypeRange(), - ValueRange({operand, handle, resumePtr.res()})); + ValueRange({operand, handle, resumePtr.getRes()})); rewriter.eraseOp(op); return success(); @@ -771,8 +771,8 @@ // Call async runtime API to execute a coroutine in the managed thread. 
auto coroHdl = adaptor.handle(); - rewriter.replaceOpWithNewOp(op, TypeRange(), kExecute, - ValueRange({coroHdl, resumePtr.res()})); + rewriter.replaceOpWithNewOp( + op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); return success(); } diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -93,7 +93,7 @@ auto elementType = global.getType().cast().getElementType(); Value memory = rewriter.create( - loc, LLVM::LLVMPointerType::get(elementType, global.addr_space()), + loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()), address, ArrayRef{zero, zero}); // Build a memref descriptor pointing to the buffer to plug with the diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -520,7 +520,7 @@ static bool isDefinedByCallTo(Value value, StringRef functionName) { assert(value.getType().isa()); if (auto defOp = value.getDefiningOp()) - return defOp.callee()->equals(functionName); + return defOp.getCallee()->equals(functionName); return false; } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -103,15 +103,16 @@ /// Checks whether the given LLVM::CallOp is a vulkan launch call op. bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { - return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && + return (callOp.getCallee() && + callOp.getCallee().getValue() == kVulkanLaunch && callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); } /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call /// op. bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { - return (callOp.callee() && - callOp.callee().getValue() == kCInterfaceVulkanLaunch && + return (callOp.getCallee() && + callOp.getCallee().getValue() == kCInterfaceVulkanLaunch && callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); } diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -377,7 +377,10 @@ return idx; if (auto constantOp = dimOp.index().getDefiningOp()) - return constantOp.value().cast().getValue().getSExtValue(); + return constantOp.getValue() + .cast() + .getValue() + .getSExtValue(); return llvm::None; } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -421,7 +421,7 @@ return val; if (auto constOp = val.getDefiningOp()) return rewriter.create(constOp.getLoc(), - constOp.value()); + constOp.getValue()); return {}; }; diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -92,17 +92,17 @@ // Detect whether the comparison is less-than or greater-than, otherwise bail. 
bool isLess; - if (llvm::find(lessThanPredicates, compare.predicate()) != + if (llvm::find(lessThanPredicates, compare.getPredicate()) != lessThanPredicates.end()) { isLess = true; - } else if (llvm::find(greaterThanPredicates, compare.predicate()) != + } else if (llvm::find(greaterThanPredicates, compare.getPredicate()) != greaterThanPredicates.end()) { isLess = false; } else { return false; } - if (select.condition() != compare.getResult()) + if (select.getCondition() != compare.getResult()) return false; // Detect if the operands are swapped between cmpf and select. Match the @@ -112,10 +112,10 @@ // positions. constexpr unsigned kTrueValue = 1; constexpr unsigned kFalseValue = 2; - bool sameOperands = select.getOperand(kTrueValue) == compare.lhs() && - select.getOperand(kFalseValue) == compare.rhs(); - bool swappedOperands = select.getOperand(kTrueValue) == compare.rhs() && - select.getOperand(kFalseValue) == compare.lhs(); + bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() && + select.getOperand(kFalseValue) == compare.getRhs(); + bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() && + select.getOperand(kFalseValue) == compare.getLhs(); if (!sameOperands && !swappedOperands) return false; diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -29,7 +29,7 @@ using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::CstrRequireOp op, PatternRewriter &rewriter) const override { - rewriter.create(op.getLoc(), op.pred(), op.msgAttr()); + rewriter.create(op.getLoc(), op.getPred(), op.getMsgAttr()); rewriter.replaceOpWithNewOp(op, true); return success(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -40,7 +40,7 @@ ConversionPatternRewriter &rewriter) const { // Replace `any` with its first operand. // Any operand would be a valid substitution. - rewriter.replaceOp(op, {adaptor.inputs().front()}); + rewriter.replaceOp(op, {adaptor.getInputs().front()}); return success(); } @@ -57,7 +57,8 @@ if (op.getType().template isa()) return failure(); - rewriter.replaceOpWithNewOp(op, adaptor.lhs(), adaptor.rhs()); + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); return success(); } }; @@ -134,7 +135,7 @@ // representing the shape extents, the rank is the extent of the only // dimension in the tensor. 
SmallVector ranks, rankDiffs; - llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { + llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { return lb.create(v, zero); })); @@ -154,8 +155,9 @@ Value replacement = lb.create( getExtentTensorType(lb.getContext()), ValueRange{maxRank}, [&](OpBuilder &b, Location loc, ValueRange args) { - Value broadcastedDim = getBroadcastedDim( - ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, args[0]); + Value broadcastedDim = + getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), + rankDiffs, args[0]); b.create(loc, broadcastedDim); }); @@ -187,14 +189,14 @@ auto loc = op.getLoc(); SmallVector extentOperands; - for (auto extent : op.shape()) { + for (auto extent : op.getShape()) { extentOperands.push_back( rewriter.create(loc, extent.getLimitedValue())); } Type indexTy = rewriter.getIndexType(); Value tensor = rewriter.create(loc, indexTy, extentOperands); - Type resultTy = RankedTensorType::get({op.shape().size()}, indexTy); + Type resultTy = RankedTensorType::get({op.getShape().size()}, indexTy); rewriter.replaceOpWithNewOp(op, resultTy, tensor); return success(); } @@ -214,7 +216,7 @@ ConstSizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( - op, op.value().getSExtValue()); + op, op.getValue().getSExtValue()); return success(); } @@ -234,7 +236,7 @@ ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands, not // on shapes. - if (!llvm::all_of(op.shapes(), + if (!llvm::all_of(op.getShapes(), [](Value v) { return !v.getType().isa(); })) return failure(); @@ -248,7 +250,7 @@ // representing the shape extents, the rank is the extent of the only // dimension in the tensor. SmallVector ranks, rankDiffs; - llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { + llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { return lb.create(v, zero); })); @@ -276,10 +278,10 @@ // could reuse the Broadcast lowering entirely, but we redo the work // here to make optimizations easier between the two loops. Value broadcastedDim = getBroadcastedDim( - ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, iv); + ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv); Value broadcastable = iterArgs[0]; - for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) { + for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; Value outOfBounds = b.create( @@ -339,16 +341,17 @@ // Derive shape extent directly from shape origin if possible. This // circumvents the necessity to materialize the shape in memory. 
- if (auto shapeOfOp = op.shape().getDefiningOp()) { - if (shapeOfOp.arg().getType().isa()) { - rewriter.replaceOpWithNewOp(op, shapeOfOp.arg(), - adaptor.dim()); + if (auto shapeOfOp = op.getShape().getDefiningOp()) { + if (shapeOfOp.getArg().getType().isa()) { + rewriter.replaceOpWithNewOp(op, shapeOfOp.getArg(), + adaptor.getDim()); return success(); } } - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), adaptor.shape(), ValueRange{adaptor.dim()}); + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + adaptor.getShape(), + ValueRange{adaptor.getDim()}); return success(); } @@ -370,7 +373,7 @@ if (op.getType().isa()) return failure(); - rewriter.replaceOpWithNewOp(op, adaptor.shape(), 0); + rewriter.replaceOpWithNewOp(op, adaptor.getShape(), 0); return success(); } @@ -390,7 +393,7 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands. - if (op.shape().getType().isa()) + if (op.getShape().getType().isa()) return failure(); auto loc = op.getLoc(); @@ -399,12 +402,12 @@ Value one = rewriter.create(loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = - rewriter.create(loc, indexTy, adaptor.shape(), zero); + rewriter.create(loc, indexTy, adaptor.getShape(), zero); auto loop = rewriter.create( - loc, zero, rank, one, op.initVals(), + loc, zero, rank, one, op.getInitVals(), [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - Value extent = b.create(loc, adaptor.shape(), iv); + Value extent = b.create(loc, adaptor.getShape(), iv); SmallVector mappedValues{iv, extent}; mappedValues.append(args.begin(), args.end()); @@ -468,12 +471,12 @@ LogicalResult ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (!llvm::all_of(op.shapes(), + if (!llvm::all_of(op.getShapes(), [](Value v) { return !v.getType().isa(); })) return failure(); Type i1Ty = rewriter.getI1Type(); - if (op.shapes().size() <= 1) { + if (op.getShapes().size() <= 1) { rewriter.replaceOpWithNewOp(op, i1Ty, rewriter.getBoolAttr(true)); return success(); @@ -482,12 +485,12 @@ auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); Value zero = rewriter.create(loc, 0); - Value firstShape = adaptor.shapes().front(); + Value firstShape = adaptor.getShapes().front(); Value firstRank = rewriter.create(loc, indexTy, firstShape, zero); Value result = nullptr; // Generate a linear sequence of compares, all with firstShape as lhs. - for (Value shape : adaptor.shapes().drop_front(1)) { + for (Value shape : adaptor.getShapes().drop_front(1)) { Value rank = rewriter.create(loc, indexTy, shape, zero); Value eqRank = rewriter.create(loc, arith::CmpIPredicate::eq, firstRank, rank); @@ -545,7 +548,7 @@ // For ranked tensor arguments, lower to `tensor.from_elements`. auto loc = op.getLoc(); - Value tensor = adaptor.arg(); + Value tensor = adaptor.getArg(); Type tensorTy = tensor.getType(); if (tensorTy.isa()) { @@ -602,16 +605,16 @@ ConversionPatternRewriter &rewriter) const { // Error conditions are not implemented, only lower if all operands and // results are extent tensors. 
- if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()}, + if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()}, [](Value v) { return v.getType().isa(); })) return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value zero = b.create(0); - Value rank = b.create(adaptor.operand(), zero); + Value rank = b.create(adaptor.getOperand(), zero); // index < 0 ? index + rank : index - Value originalIndex = adaptor.index(); + Value originalIndex = adaptor.getIndex(); Value add = b.create(originalIndex, rank); Value indexIsNegative = b.create(arith::CmpIPredicate::slt, originalIndex, zero); @@ -619,10 +622,10 @@ Value one = b.create(1); Value head = - b.create(adaptor.operand(), zero, index, one); + b.create(adaptor.getOperand(), zero, index, one); Value tailSize = b.create(rank, index); - Value tail = - b.create(adaptor.operand(), index, tailSize, one); + Value tail = b.create(adaptor.getOperand(), index, + tailSize, one); rewriter.replaceOp(op, {head, tail}); return success(); } @@ -636,11 +639,11 @@ LogicalResult matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.input().getType().isa()) + if (!adaptor.getInput().getType().isa()) return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); rewriter.replaceOpWithNewOp(op, op.getType(), - adaptor.input()); + adaptor.getInput()); return success(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -427,7 +427,7 @@ // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); rewriter.replaceOpWithNewOp( - op, adaptor.arg(), continuationBlock, failureBlock); + op, adaptor.getArg(), continuationBlock, failureBlock); return success(); } @@ -573,9 +573,9 @@ matchAndRewrite(RankOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type operandType = op.memrefOrTensor().getType(); + Type operandType = op.getMemrefOrTensor().getType(); if (auto unrankedMemRefType = operandType.dyn_cast()) { - UnrankedMemRefDescriptor desc(adaptor.memrefOrTensor()); + UnrankedMemRefDescriptor desc(adaptor.getMemrefOrTensor()); rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); return success(); } @@ -722,7 +722,7 @@ rewriter.getZeroAttr(rewriter.getIntegerType(32))); auto v = rewriter.create( - splatOp.getLoc(), vectorType, undef, adaptor.input(), zero); + splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); int64_t width = splatOp.getType().cast().getDimSize(0); SmallVector zeroValues(width, 0); @@ -767,7 +767,7 @@ loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); Value v = rewriter.create(loc, llvm1DVectorTy, vdesc, - adaptor.input(), zero); + adaptor.getInput(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); @@ -791,7 +791,7 @@ /// Try to match the kind of a std.atomic_rmw to determine whether to use a /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. 
static Optional matchSimpleAtomicOp(AtomicRMWOp atomicOp) { - switch (atomicOp.kind()) { + switch (atomicOp.getKind()) { case AtomicRMWKind::addf: return LLVM::AtomicBinOp::fadd; case AtomicRMWKind::addi: @@ -825,13 +825,13 @@ auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); - auto resultType = adaptor.value().getType(); + auto resultType = adaptor.getValue().getType(); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = - getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), + adaptor.getIndices(), rewriter); rewriter.replaceOpWithNewOp( - atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(), + atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(), LLVM::AtomicOrdering::acq_rel); return success(); } @@ -889,9 +889,9 @@ // Compute the loaded value and branch to the loop block. rewriter.setInsertionPointToEnd(initBlock); - auto memRefType = atomicOp.memref().getType().cast(); - auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + auto memRefType = atomicOp.getMemref().getType().cast(); + auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(), + adaptor.getIndices(), rewriter); Value init = rewriter.create(loc, dataPtr); rewriter.create(loc, init, loopBlock); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -154,8 +154,9 @@ LogicalResult SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp( - op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); + rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), + adaptor.getTrueValue(), + adaptor.getFalseValue()); return success(); } @@ -169,7 +170,7 @@ auto dstVecType = op.getType().dyn_cast(); if (!dstVecType || !spirv::CompositeType::isValid(dstVecType)) return failure(); - SmallVector source(dstVecType.getNumElements(), adaptor.input()); + SmallVector source(dstVecType.getNumElements(), adaptor.getInput()); rewriter.replaceOpWithNewOp(op, dstVecType, source); return success(); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -121,7 +121,7 @@ auto vecType = constantOp.getType().dyn_cast(); if (!vecType || vecType.getRank() != 2) return false; - return constantOp.value().isa(); + return constantOp.getValue().isa(); } /// Return true if this is a broadcast from scalar to a 2D vector. 
@@ -329,7 +329,7 @@ llvm::DenseMap<Value, Value> &valueMapping) { assert(constantSupportsMMAMatrixType(op)); OpBuilder b(op); - Attribute splat = op.value().cast<SplatElementsAttr>().getSplatValue(); + Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue(); auto scalarConstant = b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -949,7 +949,7 @@ return nullptr; auto vecTy = getVectorType(scalarTy, state.strategy); - auto vecAttr = DenseElementsAttr::get(vecTy, constOp.value()); + auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue()); OpBuilder::InsertionGuard guard(state.builder); Operation *parentOp = state.builder.getInsertionBlock()->getParentOp(); @@ -1253,7 +1253,7 @@ Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy, state.builder, value.getLoc()); if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp())) - return constOp.value() == valueAttr; + return constOp.getValue() == valueAttr; return false; } diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -82,7 +82,7 @@ void arith::ConstantOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { auto type = getType(); - if (auto intCst = value().dyn_cast<IntegerAttr>()) { + if (auto intCst = getValue().dyn_cast<IntegerAttr>()) { auto intType = type.dyn_cast<IntegerType>(); // Sugar i1 constants with 'true' and 'false'. @@ -106,15 +106,15 @@ static LogicalResult verify(arith::ConstantOp op) { auto type = op.getType(); // The value's type must match the return type. - if (op.value().getType() != type) { - return op.emitOpError() << "value type " << op.value().getType() + if (op.getValue().getType() != type) { + return op.emitOpError() << "value type " << op.getValue().getType() << " must match return type: " << type; } // Integer values must be signless. if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless()) return op.emitOpError("integer return type must be signless"); // Any float or elements attribute are acceptable. 
- if (!op.value().isa()) { + if (!op.getValue().isa()) { return op.emitOpError( "value must be an integer, float, or elements attribute"); } @@ -133,7 +133,7 @@ } OpFoldResult arith::ConstantOp::fold(ArrayRef operands) { - return value(); + return getValue(); } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, @@ -187,8 +187,8 @@ OpFoldResult arith::AddIOp::fold(ArrayRef operands) { // addi(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a + b; }); @@ -209,8 +209,8 @@ if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); // subi(x,0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a - b; }); @@ -229,10 +229,10 @@ OpFoldResult arith::MulIOp::fold(ArrayRef operands) { // muli(x, 0) -> 0 - if (matchPattern(rhs(), m_Zero())) - return rhs(); + if (matchPattern(getRhs(), m_Zero())) + return getRhs(); // muli(x, 1) -> x - if (matchPattern(rhs(), m_One())) + if (matchPattern(getRhs(), m_One())) return getOperand(0); // TODO: Handle the overflow case. @@ -259,10 +259,10 @@ // Fold out division by one. Assumes all tensors of all ones are splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) - return lhs(); + return getLhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) - return lhs(); + return getLhs(); } return div0 ? Attribute() : result; @@ -286,10 +286,10 @@ // Fold out division by one. Assumes all tensors of all ones are splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) - return lhs(); + return getLhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) - return lhs(); + return getLhs(); } return overflowOrDiv0 ? Attribute() : result; @@ -346,10 +346,10 @@ // splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) - return lhs(); + return getLhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) - return lhs(); + return getLhs(); } return overflowOrDiv0 ? Attribute() : result; @@ -395,10 +395,10 @@ // splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) - return lhs(); + return getLhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) - return lhs(); + return getLhs(); } return overflowOrDiv0 ? 
Attribute() : result; @@ -458,15 +458,15 @@ OpFoldResult arith::AndIOp::fold(ArrayRef operands) { /// and(x, 0) -> 0 - if (matchPattern(rhs(), m_Zero())) - return rhs(); + if (matchPattern(getRhs(), m_Zero())) + return getRhs(); /// and(x, allOnes) -> x APInt intValue; - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) - return lhs(); + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) + return getLhs(); /// and(x, x) -> x - if (lhs() == rhs()) - return rhs(); + if (getLhs() == getRhs()) + return getRhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a & b; }); @@ -478,11 +478,11 @@ OpFoldResult arith::OrIOp::fold(ArrayRef operands) { /// or(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); /// or(x, x) -> x - if (lhs() == rhs()) - return rhs(); + if (getLhs() == getRhs()) + return getRhs(); /// or(x, ) -> if (auto rhsAttr = operands[1].dyn_cast_or_null()) if (rhsAttr.getValue().isAllOnes()) @@ -498,10 +498,10 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef operands) { /// xor(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); /// xor(x, x) -> 0 - if (lhs() == rhs()) + if (getLhs() == getRhs()) return Builder(getContext()).getZeroAttr(getType()); return constFoldBinaryOp(operands, @@ -599,7 +599,7 @@ // Extend ops can only extend to a wider type. template static LogicalResult verifyExtOp(Op op) { - Type srcType = getElementTypeOrSelf(op.in().getType()); + Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); if (srcType.cast().getWidth() >= dstType.cast().getWidth()) @@ -612,7 +612,7 @@ // Truncate ops can only truncate to a shorter type. 
template static LogicalResult verifyTruncateOp(Op op) { - Type srcType = getElementTypeOrSelf(op.in().getType()); + Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); if (srcType.cast().getWidth() <= dstType.cast().getWidth()) @@ -935,7 +935,7 @@ assert(operands.size() == 2 && "cmpi takes two operands"); // cmpi(pred, x, x) - if (lhs() == rhs()) { + if (getLhs() == getRhs()) { auto val = applyCmpPredicateToEqualOperands(getPredicate()); return BoolAttr::get(getContext(), val); } diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp @@ -24,7 +24,7 @@ ConversionPatternRewriter &rewriter) const override { auto tensorType = op.getType().cast(); rewriter.replaceOpWithNewOp( - op, adaptor.in(), + op, adaptor.getIn(), MemRefType::get(tensorType.getShape(), tensorType.getElementType())); return success(); } diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -23,8 +23,8 @@ Location loc = op.getLoc(); auto signedCeilDivIOp = cast(op); Type type = signedCeilDivIOp.getType(); - Value a = signedCeilDivIOp.lhs(); - Value b = signedCeilDivIOp.rhs(); + Value a = signedCeilDivIOp.getLhs(); + Value b = signedCeilDivIOp.getRhs(); Value plusOne = rewriter.create( loc, rewriter.getIntegerAttr(type, 1)); Value zero = rewriter.create( @@ -79,8 +79,8 @@ Location loc = op.getLoc(); arith::FloorDivSIOp signedFloorDivIOp = cast(op); Type type = signedFloorDivIOp.getType(); - Value a = signedFloorDivIOp.lhs(); - Value b = signedFloorDivIOp.rhs(); + Value a = signedFloorDivIOp.getLhs(); + Value b = signedFloorDivIOp.getRhs(); Value plusOne = rewriter.create(loc, 1, type); Value zero = rewriter.create(loc, 0, type); Value minusOne = rewriter.create(loc, -1, type); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -578,7 +578,7 @@ Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); - rewriter.create(loc, adaptor.arg(), + rewriter.create(loc, adaptor.getArg(), /*trueDest=*/cont, /*trueArgs=*/ArrayRef(), /*falseDest=*/setupSetErrorBlock(coro), diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -74,17 +74,17 @@ // Printing/parsing for LLVM::CmpOp. 
//===----------------------------------------------------------------------===// static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { - p << " \"" << stringifyICmpPredicate(op.predicate()) << "\" " + p << " \"" << stringifyICmpPredicate(op.getPredicate()) << "\" " << op.getOperand(0) << ", " << op.getOperand(1); p.printOptionalAttrDict(op->getAttrs(), {"predicate"}); - p << " : " << op.lhs().getType(); + p << " : " << op.getLhs().getType(); } static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { - p << " \"" << stringifyFCmpPredicate(op.predicate()) << "\" " + p << " \"" << stringifyFCmpPredicate(op.getPredicate()) << "\" " << op.getOperand(0) << ", " << op.getOperand(1); p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"}); - p << " : " << op.lhs().getType(); + p << " : " << op.getLhs().getType(); } // ::= `llvm.icmp` string-literal ssa-use `,` ssa-use @@ -159,11 +159,11 @@ static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { auto elemTy = op.getType().cast().getElementType(); - auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()}, - {op.getType()}); + auto funcTy = FunctionType::get( + op.getContext(), {op.getArraySize().getType()}, {op.getType()}); - p << ' ' << op.arraySize() << " x " << elemTy; - if (op.alignment().hasValue() && *op.alignment() != 0) + p << ' ' << op.getArraySize() << " x " << elemTy; + if (op.getAlignment().hasValue() && *op.getAlignment() != 0) p.printOptionalAttrDict(op->getAttrs()); else p.printOptionalAttrDict(op->getAttrs(), {"alignment"}); @@ -215,7 +215,7 @@ Optional BrOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return destOperandsMutable(); + return getDestOperandsMutable(); } //===----------------------------------------------------------------------===// @@ -225,7 +225,8 @@ Optional CondBrOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable(); + return index == 0 ? getTrueDestOperandsMutable() + : getFalseDestOperandsMutable(); } //===----------------------------------------------------------------------===// @@ -691,7 +692,7 @@ // catch - global addresses only. // Bitcast ops should have global addresses as their args. if (auto bcOp = value.getDefiningOp()) { - if (auto addrOp = bcOp.arg().getDefiningOp()) + if (auto addrOp = bcOp.getArg().getDefiningOp()) continue; return op.emitError("constant clauses expected") .attachNote(bcOp.getLoc()) @@ -771,7 +772,7 @@ bool isIndirect = false; // If this is an indirect call, the callee attribute is missing. 
- FlatSymbolRefAttr calleeName = op.calleeAttr(); + FlatSymbolRefAttr calleeName = op.getCalleeAttr(); if (!calleeName) { isIndirect = true; if (!op.getNumOperands()) @@ -845,7 +846,7 @@ } static void printCallOp(OpAsmPrinter &p, CallOp &op) { - auto callee = op.callee(); + auto callee = op.getCallee(); bool isDirect = callee.hasValue(); // Print the direct callee if present as a function attribute, or an indirect @@ -962,10 +963,10 @@ } static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { - p << ' ' << op.vector() << "[" << op.position() << " : " - << op.position().getType() << "]"; + p << ' ' << op.getVector() << "[" << op.getPosition() << " : " + << op.getPosition().getType() << "]"; p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op.vector().getType(); + p << " : " << op.getVector().getType(); } // ::= `llvm.extractelement` ssa-use `, ` ssa-use @@ -991,16 +992,16 @@ } static LogicalResult verify(ExtractElementOp op) { - Type vectorType = op.vector().getType(); + Type vectorType = op.getVector().getType(); if (!LLVM::isCompatibleVectorType(vectorType)) return op->emitOpError("expected LLVM dialect-compatible vector type for " "operand #1, got") << vectorType; Type valueType = LLVM::getVectorElementType(vectorType); - if (valueType != op.res().getType()) + if (valueType != op.getRes().getType()) return op.emitOpError() << "Type mismatch: extracting from " << vectorType << " should produce " << valueType - << " but this op returns " << op.res().getType(); + << " but this op returns " << op.getRes().getType(); return success(); } @@ -1009,9 +1010,9 @@ //===----------------------------------------------------------------------===// static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { - p << ' ' << op.container() << op.position(); + p << ' ' << op.getContainer() << op.getPosition(); p.printOptionalAttrDict(op->getAttrs(), {"position"}); - p << " : " << op.container().getType(); + p << " : " << op.getContainer().getType(); } // Extract the type at `position` in the wrapped LLVM IR aggregate type @@ -1133,9 +1134,9 @@ } OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef operands) { - auto insertValueOp = container().getDefiningOp(); + auto insertValueOp = getContainer().getDefiningOp(); while (insertValueOp) { - if (position() == insertValueOp.position()) + if (getPosition() == insertValueOp.position()) return insertValueOp.value(); insertValueOp = insertValueOp.container().getDefiningOp(); } @@ -1143,16 +1144,16 @@ } static LogicalResult verify(ExtractValueOp op) { - Type valueType = getInsertExtractValueElementType(op.container().getType(), - op.positionAttr(), op); + Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), + op.getPositionAttr(), op); if (!valueType) return failure(); - if (op.res().getType() != valueType) + if (op.getRes().getType() != valueType) return op.emitOpError() - << "Type mismatch: extracting from " << op.container().getType() + << "Type mismatch: extracting from " << op.getContainer().getType() << " should produce " << valueType << " but this op returns " - << op.res().getType(); + << op.getRes().getType(); return success(); } @@ -1339,12 +1340,12 @@ GlobalOp AddressOfOp::getGlobal() { return lookupSymbolInModule((*this)->getParentOp(), - global_name()); + getGlobalName()); } LLVMFuncOp AddressOfOp::getFunction() { return lookupSymbolInModule((*this)->getParentOp(), - global_name()); + getGlobalName()); } static LogicalResult verify(AddressOfOp op) { @@ -1355,7 +1356,7 @@ "must reference a 
global defined by 'llvm.mlir.global' or 'llvm.func'"); if (global && - LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) != + LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) != op.getResult().getType()) return op.emitOpError( "the type must be a pointer to the type of the referenced global"); @@ -1386,7 +1387,7 @@ bool dsoLocal, ArrayRef attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - result.addAttribute("type", TypeAttr::get(type)); + result.addAttribute("global_type", TypeAttr::get(type)); if (isConstant) result.addAttribute("constant", builder.getUnitAttr()); if (value) @@ -1400,7 +1401,7 @@ if (alignment != 0) result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); - result.addAttribute(getLinkageAttrName(), + result.addAttribute(::getLinkageAttrName(), LinkageAttr::get(builder.getContext(), linkage)); if (addrSpace != 0) result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace)); @@ -1409,15 +1410,15 @@ } static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { - p << ' ' << stringifyLinkage(op.linkage()) << ' '; - if (auto unnamedAddr = op.unnamed_addr()) { + p << ' ' << stringifyLinkage(op.getLinkage()) << ' '; + if (auto unnamedAddr = op.getUnnamedAddr()) { StringRef str = stringifyUnnamedAddr(*unnamedAddr); if (!str.empty()) p << str << ' '; } - if (op.constant()) + if (op.getConstant()) p << "constant "; - p.printSymbolName(op.sym_name()); + p.printSymbolName(op.getSymName()); p << '('; if (auto value = op.getValueOrNull()) p.printAttribute(value); @@ -1426,14 +1427,14 @@ // default syntax here, even though it is an inherent attribute // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) p.printOptionalAttrDict(op->getAttrs(), - {SymbolTable::getSymbolAttrName(), "type", "constant", - "value", getLinkageAttrName(), + {SymbolTable::getSymbolAttrName(), "global_type", + "constant", "value", getLinkageAttrName(), getUnnamedAddrAttrName()}); // Print the trailing type unless it's a string global. 
if (op.getValueOrNull().dyn_cast_or_null()) return; - p << " : " << op.type(); + p << " : " << op.getType(); Region &initializer = op.getInitializerRegion(); if (!initializer.empty()) @@ -1546,7 +1547,7 @@ return failure(); } - result.addAttribute("type", TypeAttr::get(types[0])); + result.addAttribute("global_type", TypeAttr::get(types[0])); return success(); } @@ -1595,7 +1596,7 @@ return op.emitOpError("cannot have both initializer value and region"); } - if (op.linkage() == Linkage::Common) { + if (op.getLinkage() == Linkage::Common) { if (Attribute value = op.getValueOrNull()) { if (!isZeroAttribute(value)) { return op.emitOpError() @@ -1605,7 +1606,7 @@ } } - if (op.linkage() == Linkage::Appending) { + if (op.getLinkage() == Linkage::Appending) { if (!op.getType().isa()) { return op.emitOpError() << "expected array type for '" @@ -1613,7 +1614,7 @@ } } - Optional alignAttr = op.alignment(); + Optional alignAttr = op.getAlignment(); if (alignAttr.hasValue()) { uint64_t value = alignAttr.getValue(); if (!llvm::isPowerOf2_64(value)) @@ -1697,7 +1698,7 @@ result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); result.addAttribute("type", TypeAttr::get(type)); - result.addAttribute(getLinkageAttrName(), + result.addAttribute(::getLinkageAttrName(), LinkageAttr::get(builder.getContext(), linkage)); result.attributes.append(attrs.begin(), attrs.end()); if (dsoLocal) @@ -1903,7 +1904,7 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(LLVM::ConstantOp op) { - if (StringAttr sAttr = op.value().dyn_cast()) { + if (StringAttr sAttr = op.getValue().dyn_cast()) { auto arrayType = op.getType().dyn_cast(); if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || !arrayType.getElementType().isInteger(8)) { @@ -1920,7 +1921,7 @@ "same type, the type of a complex constant"; } - auto arrayAttr = op.value().dyn_cast(); + auto arrayAttr = op.getValue().dyn_cast(); if (!arrayAttr || arrayAttr.size() != 2 || arrayAttr[0].getType() != arrayAttr[1].getType()) { return op.emitOpError() << "expected array attribute with two elements, " @@ -1936,7 +1937,7 @@ } return success(); } - if (!op.value().isa()) + if (!op.getValue().isa()) return op.emitOpError() << "only supports integer, float, string or elements attributes"; return success(); @@ -2004,10 +2005,10 @@ //===----------------------------------------------------------------------===// static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { - p << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' ' << op.ptr() << ", " - << op.val() << ' ' << stringifyAtomicOrdering(op.ordering()) << ' '; + p << ' ' << stringifyAtomicBinOp(op.getBinOp()) << ' ' << op.getPtr() << ", " + << op.getVal() << ' ' << stringifyAtomicOrdering(op.getOrdering()) << ' '; p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"}); - p << " : " << op.res().getType(); + p << " : " << op.getRes().getType(); } // ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword @@ -2031,19 +2032,20 @@ } static LogicalResult verify(AtomicRMWOp op) { - auto ptrType = op.ptr().getType().cast(); - auto valType = op.val().getType(); + auto ptrType = op.getPtr().getType().cast(); + auto valType = op.getVal().getType(); if (valType != ptrType.getElementType()) return op.emitOpError("expected LLVM IR element type for operand #0 to " "match type for operand #1"); - auto resType = op.res().getType(); + auto resType = op.getRes().getType(); if (resType != valType) return op.emitOpError( 
"expected LLVM IR result type to match type for operand #1"); - if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) { + if (op.getBinOp() == AtomicBinOp::fadd || + op.getBinOp() == AtomicBinOp::fsub) { if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) return op.emitOpError("expected LLVM IR floating point type"); - } else if (op.bin_op() == AtomicBinOp::xchg) { + } else if (op.getBinOp() == AtomicBinOp::xchg) { auto intType = valType.dyn_cast(); unsigned intBitWidth = intType ? intType.getWidth() : 0; if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && @@ -2059,7 +2061,7 @@ return op.emitOpError("expected LLVM IR integer type"); } - if (static_cast(op.ordering()) < + if (static_cast(op.getOrdering()) < static_cast(AtomicOrdering::monotonic)) return op.emitOpError() << "expected at least '" @@ -2074,12 +2076,12 @@ //===----------------------------------------------------------------------===// static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { - p << ' ' << op.ptr() << ", " << op.cmp() << ", " << op.val() << ' ' - << stringifyAtomicOrdering(op.success_ordering()) << ' ' - << stringifyAtomicOrdering(op.failure_ordering()); + p << ' ' << op.getPtr() << ", " << op.getCmp() << ", " << op.getVal() << ' ' + << stringifyAtomicOrdering(op.getSuccessOrdering()) << ' ' + << stringifyAtomicOrdering(op.getFailureOrdering()); p.printOptionalAttrDict(op->getAttrs(), {"success_ordering", "failure_ordering"}); - p << " : " << op.val().getType(); + p << " : " << op.getVal().getType(); } // ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use @@ -2111,11 +2113,11 @@ } static LogicalResult verify(AtomicCmpXchgOp op) { - auto ptrType = op.ptr().getType().cast(); + auto ptrType = op.getPtr().getType().cast(); if (!ptrType) return op.emitOpError("expected LLVM IR pointer type for operand #0"); - auto cmpType = op.cmp().getType(); - auto valType = op.val().getType(); + auto cmpType = op.getCmp().getType(); + auto valType = op.getVal().getType(); if (cmpType != ptrType.getElementType() || cmpType != valType) return op.emitOpError("expected LLVM IR element type for operand #0 to " "match type for all other operands"); @@ -2126,11 +2128,11 @@ !valType.isa() && !valType.isa() && !valType.isa() && !valType.isa()) return op.emitOpError("unexpected LLVM IR type"); - if (op.success_ordering() < AtomicOrdering::monotonic || - op.failure_ordering() < AtomicOrdering::monotonic) + if (op.getSuccessOrdering() < AtomicOrdering::monotonic || + op.getFailureOrdering() < AtomicOrdering::monotonic) return op.emitOpError("ordering must be at least 'monotonic'"); - if (op.failure_ordering() == AtomicOrdering::release || - op.failure_ordering() == AtomicOrdering::acq_rel) + if (op.getFailureOrdering() == AtomicOrdering::release || + op.getFailureOrdering() == AtomicOrdering::acq_rel) return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); return success(); } @@ -2164,13 +2166,13 @@ p << ' '; if (!op->getAttr(syncscopeKeyword).cast().getValue().empty()) p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") "; - p << stringifyAtomicOrdering(op.ordering()); + p << stringifyAtomicOrdering(op.getOrdering()); } static LogicalResult verify(FenceOp &op) { - if (op.ordering() == AtomicOrdering::not_atomic || - op.ordering() == AtomicOrdering::unordered || - op.ordering() == AtomicOrdering::monotonic) + if (op.getOrdering() == AtomicOrdering::not_atomic || + op.getOrdering() == AtomicOrdering::unordered || + op.getOrdering() == AtomicOrdering::monotonic) 
return op.emitOpError("can be given only acquire, release, acq_rel, " "and seq_cst orderings"); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -61,7 +61,7 @@ Attribute attr = ofr.dyn_cast(); // Note: isa+cast-like pattern allows writing the condition below as 1 line. if (!attr && ofr.get().getDefiningOp()) - attr = ofr.get().getDefiningOp().value(); + attr = ofr.get().getDefiningOp().getValue(); if (auto intAttr = attr.dyn_cast_or_null()) return intAttr.getValue().getSExtValue(); return llvm::None; diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -187,7 +187,7 @@ : cst.getValue(); } else if (auto constIndexOp = size.getDefiningOp()) { if (constIndexOp.getType().isa()) - boundingConst = constIndexOp.value().cast().getInt(); + boundingConst = constIndexOp.getValue().cast().getInt(); } else if (auto affineApplyOp = size.getDefiningOp()) { if (auto cExpr = affineApplyOp.getAffineMap() .getResult(0) @@ -196,7 +196,7 @@ } else if (auto dimOp = size.getDefiningOp()) { auto shape = dimOp.source().getType().dyn_cast(); if (auto constOp = dimOp.index().getDefiningOp()) { - if (auto indexAttr = constOp.value().dyn_cast()) { + if (auto indexAttr = constOp.getValue().dyn_cast()) { auto dimIndex = indexAttr.getInt(); if (!shape.isDynamicDim(dimIndex)) { boundingConst = shape.getShape()[dimIndex]; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -691,7 +691,7 @@ Optional DimOp::getConstantIndex() { if (auto constantOp = index().getDefiningOp()) - return constantOp.value().cast().getInt(); + return constantOp.getValue().cast().getInt(); return {}; } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -173,7 +173,7 @@ return success(); auto constOp = op.ifCond().template getDefiningOp(); - if (constOp && constOp.value().template cast().getInt()) + if (constOp && constOp.getValue().template cast().getInt()) rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); else if (constOp) rewriter.eraseOp(op); diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -714,8 +714,8 @@ return failure(); // If the loop is known to have 0 iterations, remove it. - llvm::APInt lbValue = lb.value().cast().getValue(); - llvm::APInt ubValue = ub.value().cast().getValue(); + llvm::APInt lbValue = lb.getValue().cast().getValue(); + llvm::APInt ubValue = ub.getValue().cast().getValue(); if (lbValue.sge(ubValue)) { rewriter.replaceOp(op, op.getIterOperands()); return success(); @@ -727,7 +727,7 @@ // If the loop is known to have 1 iteration, inline its body and remove the // loop. 
- llvm::APInt stepValue = step.value().cast().getValue(); + llvm::APInt stepValue = step.getValue().cast().getValue(); if ((lbValue + stepValue).sge(ubValue)) { SmallVector blockArgs; blockArgs.reserve(op.getNumIterOperands() + 1); @@ -1241,7 +1241,7 @@ if (!constant) return failure(); - if (constant.value().cast().getValue()) + if (constant.getValue().cast().getValue()) replaceOpWithRegion(rewriter, op, op.thenRegion()); else if (!op.elseRegion().empty()) replaceOpWithRegion(rewriter, op, op.elseRegion()); @@ -1425,8 +1425,8 @@ if (!falseYield) continue; - bool trueVal = trueYield.value().cast().getValue(); - bool falseVal = falseYield.value().cast().getValue(); + bool trueVal = trueYield.getValue().cast().getValue(); + bool falseVal = falseYield.getValue().cast().getValue(); if (!trueVal && falseVal) { if (!opResult.use_empty()) { Value notCond = rewriter.create( diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -45,17 +45,17 @@ LogicalResult shape::getShapeVec(Value input, SmallVectorImpl &shapeValues) { if (auto inputOp = input.getDefiningOp()) { - auto type = inputOp.arg().getType().dyn_cast(); + auto type = inputOp.getArg().getType().dyn_cast(); if (!type.hasRank()) return failure(); shapeValues = llvm::to_vector<6>(type.getShape()); return success(); } else if (auto inputOp = input.getDefiningOp()) { - shapeValues = llvm::to_vector<6>(inputOp.shape().getValues()); + shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues()); return success(); } else if (auto inputOp = input.getDefiningOp()) { shapeValues = llvm::to_vector<6>( - inputOp.value().cast().getValues()); + inputOp.getValue().cast().getValues()); return success(); } else { return failure(); @@ -218,7 +218,7 @@ if (!shapeFnLib) return op->emitError() << it << " does not refer to FunctionLibraryOp"; - for (auto mapping : shapeFnLib.mapping()) { + for (auto mapping : shapeFnLib.getMapping()) { if (!key.insert(mapping.first).second) { return op->emitError("only one op to shape mapping allowed, found " "multiple for `") @@ -281,13 +281,13 @@ } static void print(OpAsmPrinter &p, AssumingOp op) { - bool yieldsResults = !op.results().empty(); + bool yieldsResults = !op.getResults().empty(); - p << " " << op.witness(); + p << " " << op.getWitness(); if (yieldsResults) { p << " -> (" << op.getResultTypes() << ")"; } - p.printRegion(op.doRegion(), + p.printRegion(op.getDoRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/yieldsResults); p.printOptionalAttrDict(op->getAttrs()); @@ -300,8 +300,8 @@ LogicalResult matchAndRewrite(AssumingOp op, PatternRewriter &rewriter) const override { - auto witness = op.witness().getDefiningOp(); - if (!witness || !witness.passingAttr()) + auto witness = op.getWitness().getDefiningOp(); + if (!witness || !witness.getPassingAttr()) return failure(); AssumingOp::inlineRegionIntoParent(op, rewriter); @@ -320,7 +320,7 @@ // Find used values. 
SmallVector newYieldOperands; Value opResult, yieldOperand; - for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) { + for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { std::tie(opResult, yieldOperand) = it; if (!opResult.getUses().empty()) { newYieldOperands.push_back(yieldOperand); @@ -338,8 +338,8 @@ rewriter.replaceOpWithNewOp(yieldOp, newYieldOperands); rewriter.setInsertionPoint(op); auto newOp = rewriter.create( - op.getLoc(), newYieldOp->getOperandTypes(), op.witness()); - newOp.doRegion().takeBody(op.doRegion()); + op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); + newOp.getDoRegion().takeBody(op.getDoRegion()); // Use the new results to replace the previously used ones. SmallVector replacementValues; @@ -373,7 +373,7 @@ return; } - regions.push_back(RegionSuccessor(&doRegion())); + regions.push_back(RegionSuccessor(&getDoRegion())); } void AssumingOp::inlineRegionIntoParent(AssumingOp &op, @@ -386,7 +386,7 @@ // Remove the AssumingOp and AssumingYieldOp. auto &yieldOp = assumingBlock->back(); - rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); + rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); rewriter.replaceOp(op, yieldOp.getOperands()); rewriter.eraseOp(&yieldOp); @@ -440,8 +440,8 @@ OpFoldResult mlir::shape::AddOp::fold(ArrayRef operands) { // add(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) - return lhs(); + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a + b; }); @@ -459,16 +459,16 @@ LogicalResult matchAndRewrite(AssumingAllOp op, PatternRewriter &rewriter) const override { SmallVector shapes; - for (Value w : op.inputs()) { + for (Value w : op.getInputs()) { auto cstrEqOp = w.getDefiningOp(); if (!cstrEqOp) return failure(); - bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) { + bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { return llvm::is_contained(shapes, s); }); - if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes) + if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) return failure(); - shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end()); + shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); } rewriter.replaceOpWithNewOp(op, shapes); return success(); @@ -545,15 +545,15 @@ //===----------------------------------------------------------------------===// OpFoldResult BroadcastOp::fold(ArrayRef operands) { - if (shapes().size() == 1) { + if (getShapes().size() == 1) { // Otherwise, we need a cast which would be a canonicalization, not folding. - if (shapes().front().getType() != getType()) + if (getShapes().front().getType() != getType()) return nullptr; - return shapes().front(); + return getShapes().front(); } // TODO: Support folding with more than 2 input shapes - if (shapes().size() > 2) + if (getShapes().size() > 2) return nullptr; if (!operands[0] || !operands[1]) @@ -590,7 +590,7 @@ return false; } if (auto constShape = shape.getDefiningOp()) { - if (constShape.shape().empty()) + if (constShape.getShape().empty()) return false; } return true; @@ -617,7 +617,7 @@ PatternRewriter &rewriter) const override { if (op.getNumOperands() != 1) return failure(); - Value replacement = op.shapes().front(); + Value replacement = op.getShapes().front(); // Insert cast if needed. 
if (replacement.getType() != op.getType()) { @@ -646,12 +646,12 @@ PatternRewriter &rewriter) const override { SmallVector foldedConstantShape; SmallVector newShapeOperands; - for (Value shape : op.shapes()) { + for (Value shape : op.getShapes()) { if (auto constShape = shape.getDefiningOp()) { SmallVector newFoldedConstantShape; if (OpTrait::util::getBroadcastedShape( foldedConstantShape, - llvm::to_vector<8>(constShape.shape().getValues()), + llvm::to_vector<8>(constShape.getShape().getValues()), newFoldedConstantShape)) { foldedConstantShape = newFoldedConstantShape; continue; @@ -721,7 +721,7 @@ // Infer resulting shape rank if possible. int64_t maxRank = 0; - for (Value shape : op.shapes()) { + for (Value shape : op.getShapes()) { if (auto extentTensorTy = shape.getType().dyn_cast()) { // Cannot infer resulting shape rank if any operand is dynamically // ranked. @@ -732,7 +732,8 @@ } auto newOp = rewriter.create( - op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes()); + op.getLoc(), getExtentTensorType(getContext(), maxRank), + op.getShapes()); rewriter.replaceOpWithNewOp(op, op.getType(), newOp); return success(); } @@ -775,7 +776,7 @@ p << " "; p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); p << "["; - interleaveComma(op.shape().getValues(), p, + interleaveComma(op.getShape().getValues(), p, [&](int64_t i) { p << i; }); p << "] : "; p.printType(op.getType()); @@ -811,7 +812,7 @@ return success(); } -OpFoldResult ConstShapeOp::fold(ArrayRef) { return shapeAttr(); } +OpFoldResult ConstShapeOp::fold(ArrayRef) { return getShapeAttr(); } void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { @@ -895,7 +896,7 @@ // on the input shapes. if ([&] { SmallVector, 6> extents; - for (auto shapeValue : shapes()) { + for (auto shapeValue : getShapes()) { extents.emplace_back(); if (failed(getShapeVec(shapeValue, extents.back()))) return false; @@ -946,13 +947,13 @@ build(builder, result, builder.getIndexAttr(value)); } -OpFoldResult ConstSizeOp::fold(ArrayRef) { return valueAttr(); } +OpFoldResult ConstSizeOp::fold(ArrayRef) { return getValueAttr(); } void ConstSizeOp::getAsmResultNames( llvm::function_ref setNameFn) { SmallString<4> buffer; llvm::raw_svector_ostream os(buffer); - os << "c" << value(); + os << "c" << getValue(); setNameFn(getResult(), os.str()); } @@ -960,7 +961,9 @@ // ConstWitnessOp //===----------------------------------------------------------------------===// -OpFoldResult ConstWitnessOp::fold(ArrayRef) { return passingAttr(); } +OpFoldResult ConstWitnessOp::fold(ArrayRef) { + return getPassingAttr(); +} //===----------------------------------------------------------------------===// // CstrRequireOp @@ -1069,7 +1072,7 @@ } FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { - auto attr = mapping() + auto attr = getMapping() .get(op->getName().getIdentifier()) .dyn_cast_or_null(); if (!attr) @@ -1111,7 +1114,7 @@ p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); p << " mapping "; - p.printAttributeWithoutType(op.mappingAttr()); + p.printAttributeWithoutType(op.getMappingAttr()); } //===----------------------------------------------------------------------===// @@ -1119,10 +1122,10 @@ //===----------------------------------------------------------------------===// Optional GetExtentOp::getConstantDim() { - if (auto constSizeOp = dim().getDefiningOp()) - return constSizeOp.value().getLimitedValue(); - if (auto constantOp = 
dim().getDefiningOp()) - return constantOp.value().cast().getInt(); + if (auto constSizeOp = getDim().getDefiningOp()) + return constSizeOp.getValue().getLimitedValue(); + if (auto constantOp = getDim().getDefiningOp()) + return constantOp.getValue().cast().getInt(); return llvm::None; } @@ -1250,11 +1253,11 @@ LogicalResult matchAndRewrite(shape::RankOp op, PatternRewriter &rewriter) const override { - auto shapeOfOp = op.shape().getDefiningOp(); + auto shapeOfOp = op.getShape().getDefiningOp(); if (!shapeOfOp) return failure(); auto rankedTensorType = - shapeOfOp.arg().getType().dyn_cast(); + shapeOfOp.getArg().getType().dyn_cast(); if (!rankedTensorType) return failure(); int64_t rank = rankedTensorType.getRank(); @@ -1333,8 +1336,8 @@ OpFoldResult MaxOp::fold(llvm::ArrayRef operands) { // If operands are equal, just propagate one. - if (lhs() == rhs()) - return lhs(); + if (getLhs() == getRhs()) + return getLhs(); return nullptr; } @@ -1365,8 +1368,8 @@ OpFoldResult MinOp::fold(llvm::ArrayRef operands) { // If operands are equal, just propagate one. - if (lhs() == rhs()) - return lhs(); + if (getLhs() == getRhs()) + return getLhs(); return nullptr; } @@ -1441,12 +1444,13 @@ LogicalResult matchAndRewrite(shape::ShapeOfOp op, PatternRewriter &rewriter) const override { - if (!op.arg().getType().isa()) + if (!op.getArg().getType().isa()) return failure(); if (op.getType().isa()) return failure(); - rewriter.replaceOpWithNewOp(op.getOperation(), op.arg()); + rewriter.replaceOpWithNewOp(op.getOperation(), + op.getArg()); return success(); } }; @@ -1474,11 +1478,11 @@ return failure(); // Argument type must be ranked and must not conflict. - auto argTy = shapeOfOp.arg().getType().dyn_cast(); + auto argTy = shapeOfOp.getArg().getType().dyn_cast(); if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) return failure(); - rewriter.replaceOpWithNewOp(op, ty, shapeOfOp.arg()); + rewriter.replaceOpWithNewOp(op, ty, shapeOfOp.getArg()); return success(); } }; @@ -1634,10 +1638,10 @@ static LogicalResult verify(ReduceOp op) { // Verify block arg types. - Block &block = op.region().front(); + Block &block = op.getRegion().front(); // The block takes index, extent, and aggregated values as arguments. - auto blockArgsCount = op.initVals().size() + 2; + auto blockArgsCount = op.getInitVals().size() + 2; if (block.getNumArguments() != blockArgsCount) return op.emitOpError() << "ReduceOp body is expected to have " << blockArgsCount << " arguments"; @@ -1651,7 +1655,7 @@ // `index`, depending on whether the reduce operation is applied to a shape or // to an extent tensor. 
Type extentTy = block.getArgument(1).getType(); - if (op.shape().getType().isa()) { + if (op.getShape().getType().isa()) { if (!extentTy.isa()) return op.emitOpError("argument 1 of ReduceOp body is expected to be of " "SizeType if the ReduceOp operates on a ShapeType"); @@ -1662,7 +1666,7 @@ "ReduceOp operates on an extent tensor"); } - for (auto type : llvm::enumerate(op.initVals())) + for (auto type : llvm::enumerate(op.getInitVals())) if (block.getArgument(type.index() + 2).getType() != type.value().getType()) return op.emitOpError() << "type mismatch between argument " << type.index() + 2 @@ -1701,10 +1705,10 @@ } static void print(OpAsmPrinter &p, ReduceOp op) { - p << '(' << op.shape() << ", " << op.initVals() - << ") : " << op.shape().getType(); + p << '(' << op.getShape() << ", " << op.getInitVals() + << ") : " << op.getShape().getType(); p.printOptionalArrowTypeList(op.getResultTypes()); - p.printRegion(op.region()); + p.printRegion(op.getRegion()); p.printOptionalAttrDict(op->getAttrs()); } diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp --- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp @@ -39,7 +39,7 @@ ->materializeConstant(rewriter, rewriter.getIndexAttr(1), valueType, loc) ->getResult(0); - ReduceOp reduce = rewriter.create(loc, op.shape(), init); + ReduceOp reduce = rewriter.create(loc, op.getShape(), init); // Generate reduce operator. Block *body = reduce.getBody(); @@ -48,7 +48,7 @@ body->getArgument(2)); b.create(loc, product); - rewriter.replaceOp(op, reduce.result()); + rewriter.replaceOp(op, reduce.getResult()); return success(); } diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp --- a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp @@ -30,10 +30,10 @@ newResultTypes.push_back(convertedType); } - auto newAssumingOp = - rewriter.create(op.getLoc(), newResultTypes, op.witness()); - rewriter.inlineRegionBefore(op.doRegion(), newAssumingOp.doRegion(), - newAssumingOp.doRegion().end()); + auto newAssumingOp = rewriter.create( + op.getLoc(), newResultTypes, op.getWitness()); + rewriter.inlineRegionBefore(op.getDoRegion(), newAssumingOp.getDoRegion(), + newAssumingOp.getDoRegion().end()); rewriter.replaceOp(op, newAssumingOp.getResults()); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -193,7 +193,7 @@ static LogicalResult isInBounds(Value dim, Value tensor) { if (auto constantOp = dim.getDefiningOp()) { - unsigned d = constantOp.value().cast().getInt(); + unsigned d = constantOp.getValue().cast().getInt(); if (d >= tensor.getType().cast().getRank()) return failure(); } @@ -227,7 +227,7 @@ continue; auto constantOp = op.sizes()[i].getDefiningOp(); if (!constantOp || - constantOp.value().cast().getInt() != shape[i]) + constantOp.getValue().cast().getInt() != shape[i]) return op.emitError("unexpected mismatch with static dimension size ") << shape[i]; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- 
a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -351,7 +351,7 @@ genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc, Value tensor) { if (auto constOp = tensor.getDefiningOp()) { - if (auto attr = constOp.value().dyn_cast()) { + if (auto attr = constOp.getValue().dyn_cast()) { DenseElementsAttr indicesAttr = attr.getIndices(); Value indices = rewriter.create(loc, indicesAttr); DenseElementsAttr valuesAttr = attr.getValues(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -146,7 +146,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { // Erase assertion if argument is constant true. - if (matchPattern(op.arg(), m_One())) { + if (matchPattern(op.getArg(), m_One())) { rewriter.eraseOp(op); return success(); } @@ -161,14 +161,14 @@ if (op.getMemRefType().getRank() != op.getNumOperands() - 2) return op.emitOpError( "expects the number of subscripts to be equal to memref rank"); - switch (op.kind()) { + switch (op.getKind()) { case AtomicRMWKind::addf: case AtomicRMWKind::maxf: case AtomicRMWKind::minf: case AtomicRMWKind::mulf: - if (!op.value().getType().isa()) + if (!op.getValue().getType().isa()) return op.emitOpError() - << "with kind '" << stringifyAtomicRMWKind(op.kind()) + << "with kind '" << stringifyAtomicRMWKind(op.getKind()) << "' expects a floating-point type"; break; case AtomicRMWKind::addi: @@ -177,9 +177,9 @@ case AtomicRMWKind::mins: case AtomicRMWKind::minu: case AtomicRMWKind::muli: - if (!op.value().getType().isa()) + if (!op.getValue().getType().isa()) return op.emitOpError() - << "with kind '" << stringifyAtomicRMWKind(op.kind()) + << "with kind '" << stringifyAtomicRMWKind(op.getKind()) << "' expects an integer type"; break; default: @@ -308,7 +308,7 @@ } static LogicalResult verify(GenericAtomicRMWOp op) { - auto &body = op.body(); + auto &body = op.getRegion(); if (body.getNumArguments() != 1) return op.emitOpError("expected single number of entry block arguments"); @@ -351,9 +351,9 @@ } static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) { - p << ' ' << op.memref() << "[" << op.indices() - << "] : " << op.memref().getType(); - p.printRegion(op.body()); + p << ' ' << op.getMemref() << "[" << op.getIndices() + << "] : " << op.getMemref().getType(); + p.printRegion(op.getRegion()); p.printOptionalAttrDict(op->getAttrs()); } @@ -363,7 +363,7 @@ static LogicalResult verify(AtomicYieldOp op) { Type parentType = op->getParentOp()->getResultTypes().front(); - Type resultType = op.result().getType(); + Type resultType = op.getResult().getType(); if (parentType != resultType) return op.emitOpError() << "types mismatch between yield op: " << resultType << " and its parent: " << parentType; @@ -467,8 +467,6 @@ succeeded(simplifyPassThroughBr(op, rewriter))); } -Block *BranchOp::getDest() { return getSuccessor(); } - void BranchOp::setDest(Block *block) { return setSuccessor(block); } void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } @@ -476,10 +474,12 @@ Optional BranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return destOperandsMutable(); + return getDestOperandsMutable(); } -Block *BranchOp::getSuccessorForOperands(ArrayRef) { return dest(); } +Block *BranchOp::getSuccessorForOperands(ArrayRef) { + 
return getDest(); +} //===----------------------------------------------------------------------===// // CallOp @@ -602,7 +602,7 @@ LogicalResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { - Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest(); + Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); ValueRange trueDestOperands = condbr.getTrueOperands(); ValueRange falseDestOperands = condbr.getFalseOperands(); SmallVector trueDestOperandStorage, falseDestOperandStorage; @@ -638,8 +638,8 @@ PatternRewriter &rewriter) const override { // Check that the true and false destinations are the same and have the same // operands. - Block *trueDest = condbr.trueDest(); - if (trueDest != condbr.falseDest()) + Block *trueDest = condbr.getTrueDest(); + if (trueDest != condbr.getFalseDest()) return failure(); // If all of the operands match, no selects need to be generated. @@ -707,12 +707,12 @@ return failure(); // Fold this branch to an unconditional branch. - if (currentBlock == predBranch.trueDest()) - rewriter.replaceOpWithNewOp(condbr, condbr.trueDest(), - condbr.trueDestOperands()); + if (currentBlock == predBranch.getTrueDest()) + rewriter.replaceOpWithNewOp(condbr, condbr.getTrueDest(), + condbr.getTrueDestOperands()); else - rewriter.replaceOpWithNewOp(condbr, condbr.falseDest(), - condbr.falseDestOperands()); + rewriter.replaceOpWithNewOp(condbr, condbr.getFalseDest(), + condbr.getFalseDestOperands()); return success(); } }; @@ -758,7 +758,7 @@ // op. if (condbr.getTrueDest()->getSinglePredecessor()) { for (OpOperand &use : - llvm::make_early_inc_range(condbr.condition().getUses())) { + llvm::make_early_inc_range(condbr.getCondition().getUses())) { if (use.getOwner()->getBlock() == condbr.getTrueDest()) { replaced = true; @@ -773,7 +773,7 @@ } if (condbr.getFalseDest()->getSinglePredecessor()) { for (OpOperand &use : - llvm::make_early_inc_range(condbr.condition().getUses())) { + llvm::make_early_inc_range(condbr.getCondition().getUses())) { if (use.getOwner()->getBlock() == condbr.getFalseDest()) { replaced = true; @@ -802,13 +802,13 @@ Optional CondBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == trueIndex ? trueDestOperandsMutable() - : falseDestOperandsMutable(); + return index == trueIndex ? getTrueDestOperandsMutable() + : getFalseDestOperandsMutable(); } Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { if (IntegerAttr condAttr = operands.front().dyn_cast_or_null()) - return condAttr.getValue().isOneValue() ? trueDest() : falseDest(); + return condAttr.getValue().isOneValue() ? 
getTrueDest() : getFalseDest(); return nullptr; } @@ -947,19 +947,19 @@ assert(operands.size() == 2 && "binary operation takes two operands"); // maxsi(x,x) -> x - if (lhs() == rhs()) - return rhs(); + if (getLhs() == getRhs()) + return getRhs(); APInt intValue; // maxsi(x,MAX_INT) -> MAX_INT - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxSignedValue()) - return rhs(); + return getRhs(); // maxsi(x, MIN_INT) -> x - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinSignedValue()) - return lhs(); + return getLhs(); return constFoldBinaryOp( operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); }); @@ -973,17 +973,17 @@ assert(operands.size() == 2 && "binary operation takes two operands"); // maxui(x,x) -> x - if (lhs() == rhs()) - return rhs(); + if (getLhs() == getRhs()) + return getRhs(); APInt intValue; // maxui(x,MAX_INT) -> MAX_INT - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) - return rhs(); + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) + return getRhs(); // maxui(x, MIN_INT) -> x - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) - return lhs(); + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) + return getLhs(); return constFoldBinaryOp( operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); }); @@ -997,19 +997,19 @@ assert(operands.size() == 2 && "binary operation takes two operands"); // minsi(x,x) -> x - if (lhs() == rhs()) - return rhs(); + if (getLhs() == getRhs()) + return getRhs(); APInt intValue; // minsi(x,MIN_INT) -> MIN_INT - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinSignedValue()) - return rhs(); + return getRhs(); // minsi(x, MAX_INT) -> x - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxSignedValue()) - return lhs(); + return getLhs(); return constFoldBinaryOp( operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); }); @@ -1023,17 +1023,17 @@ assert(operands.size() == 2 && "binary operation takes two operands"); // minui(x,x) -> x - if (lhs() == rhs()) - return rhs(); + if (getLhs() == getRhs()) + return getRhs(); APInt intValue; // minui(x,MIN_INT) -> MIN_INT - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) - return rhs(); + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) + return getRhs(); // minui(x, MAX_INT) -> x - if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) - return lhs(); + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) + return getLhs(); return constFoldBinaryOp( operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); }); @@ -1103,7 +1103,7 @@ if (!op.getType().isInteger(1)) return failure(); - rewriter.replaceOpWithNewOp(op, op.condition(), + rewriter.replaceOpWithNewOp(op, op.getCondition(), op.getFalseValue()); return success(); } @@ -1131,10 +1131,10 @@ return falseVal; if (auto cmp = dyn_cast_or_null(condition.getDefiningOp())) { - auto pred = cmp.predicate(); + auto pred = cmp.getPredicate(); if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { - auto cmpLhs = cmp.lhs(); - auto cmpRhs = cmp.rhs(); + auto cmpLhs = cmp.getLhs(); + 
auto cmpRhs = cmp.getRhs(); // %0 = arith.cmpi eq, %arg0, %arg1 // %1 = select %0, %arg0, %arg1 => %arg1 @@ -1334,13 +1334,13 @@ } static LogicalResult verify(SwitchOp op) { - auto caseValues = op.case_values(); - auto caseDestinations = op.caseDestinations(); + auto caseValues = op.getCaseValues(); + auto caseDestinations = op.getCaseDestinations(); if (!caseValues && caseDestinations.empty()) return success(); - Type flagType = op.flag().getType(); + Type flagType = op.getFlag().getType(); Type caseValueType = caseValues->getType().getElementType(); if (caseValueType != flagType) return op.emitOpError() @@ -1359,22 +1359,22 @@ Optional SwitchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? defaultOperandsMutable() + return index == 0 ? getDefaultOperandsMutable() : getCaseOperandsMutable(index - 1); } Block *SwitchOp::getSuccessorForOperands(ArrayRef operands) { - Optional caseValues = case_values(); + Optional caseValues = getCaseValues(); if (!caseValues) - return defaultDestination(); + return getDefaultDestination(); - SuccessorRange caseDests = caseDestinations(); + SuccessorRange caseDests = getCaseDestinations(); if (auto value = operands.front().dyn_cast_or_null()) { - for (int64_t i = 0, size = case_values()->size(); i < size; ++i) + for (int64_t i = 0, size = getCaseValues()->size(); i < size; ++i) if (value == caseValues->getValue(i)) return caseDests[i]; - return defaultDestination(); + return getDefaultDestination(); } return nullptr; } @@ -1385,11 +1385,11 @@ /// -> br ^bb1 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, PatternRewriter &rewriter) { - if (!op.caseDestinations().empty()) + if (!op.getCaseDestinations().empty()) return failure(); - rewriter.replaceOpWithNewOp(op, op.defaultDestination(), - op.defaultOperands()); + rewriter.replaceOpWithNewOp(op, op.getDefaultDestination(), + op.getDefaultOperands()); return success(); } @@ -1409,12 +1409,12 @@ SmallVector newCaseOperands; SmallVector newCaseValues; bool requiresChange = false; - auto caseValues = op.case_values(); - auto caseDests = op.caseDestinations(); + auto caseValues = op.getCaseValues(); + auto caseDests = op.getCaseDestinations(); for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseDests[i] == op.defaultDestination() && - op.getCaseOperands(i) == op.defaultOperands()) { + if (caseDests[i] == op.getDefaultDestination() && + op.getCaseOperands(i) == op.getDefaultOperands()) { requiresChange = true; continue; } @@ -1426,9 +1426,9 @@ if (!requiresChange) return failure(); - rewriter.replaceOpWithNewOp(op, op.flag(), op.defaultDestination(), - op.defaultOperands(), newCaseValues, - newCaseDestinations, newCaseOperands); + rewriter.replaceOpWithNewOp( + op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), + newCaseValues, newCaseDestinations, newCaseOperands); return success(); } @@ -1441,16 +1441,16 @@ /// -> br ^bb2 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, APInt caseValue) { - auto caseValues = op.case_values(); + auto caseValues = op.getCaseValues(); for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { if (caseValues->getValue(i) == caseValue) { - rewriter.replaceOpWithNewOp(op, op.caseDestinations()[i], + rewriter.replaceOpWithNewOp(op, op.getCaseDestinations()[i], op.getCaseOperands(i)); return; } } - rewriter.replaceOpWithNewOp(op, op.defaultDestination(), - op.defaultOperands()); + rewriter.replaceOpWithNewOp(op, 
op.getDefaultDestination(), + op.getDefaultOperands()); } /// switch %c_42 : i32, [ @@ -1462,7 +1462,7 @@ static LogicalResult simplifyConstSwitchValue(SwitchOp op, PatternRewriter &rewriter) { APInt caseValue; - if (!matchPattern(op.flag(), m_ConstantInt(&caseValue))) + if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) return failure(); foldSwitch(op, rewriter, caseValue); @@ -1485,8 +1485,8 @@ SmallVector newCaseDests; SmallVector newCaseOperands; SmallVector> argStorage; - auto caseValues = op.case_values(); - auto caseDests = op.caseDestinations(); + auto caseValues = op.getCaseValues(); + auto caseDests = op.getCaseDestinations(); bool requiresChange = false; for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { Block *caseDest = caseDests[i]; @@ -1499,8 +1499,8 @@ newCaseOperands.push_back(caseOperands); } - Block *defaultDest = op.defaultDestination(); - ValueRange defaultOperands = op.defaultOperands(); + Block *defaultDest = op.getDefaultDestination(); + ValueRange defaultOperands = op.getDefaultOperands(); argStorage.emplace_back(); if (succeeded( @@ -1510,7 +1510,7 @@ if (!requiresChange) return failure(); - rewriter.replaceOpWithNewOp(op, op.flag(), defaultDest, + rewriter.replaceOpWithNewOp(op, op.getFlag(), defaultDest, defaultOperands, caseValues.getValue(), newCaseDests, newCaseOperands); return success(); @@ -1564,15 +1564,15 @@ // and that it branches on the same condition and that this branch isn't the // default destination. auto predSwitch = dyn_cast(predecessor->getTerminator()); - if (!predSwitch || op.flag() != predSwitch.flag() || - predSwitch.defaultDestination() == currentBlock) + if (!predSwitch || op.getFlag() != predSwitch.getFlag() || + predSwitch.getDefaultDestination() == currentBlock) return failure(); // Fold this switch to an unconditional branch. APInt caseValue; bool isDefault = true; - SuccessorRange predDests = predSwitch.caseDestinations(); - Optional predCaseValues = predSwitch.case_values(); + SuccessorRange predDests = predSwitch.getCaseDestinations(); + Optional predCaseValues = predSwitch.getCaseValues(); for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) { if (currentBlock == predDests[i]) { caseValue = predCaseValues->getValue(i); @@ -1581,8 +1581,8 @@ } } if (isDefault) - rewriter.replaceOpWithNewOp(op, op.defaultDestination(), - op.defaultOperands()); + rewriter.replaceOpWithNewOp(op, op.getDefaultDestination(), + op.getDefaultOperands()); else foldSwitch(op, rewriter, caseValue); return success(); @@ -1621,14 +1621,14 @@ // and that it branches on the same condition and that this branch is the // default destination. auto predSwitch = dyn_cast(predecessor->getTerminator()); - if (!predSwitch || op.flag() != predSwitch.flag() || - predSwitch.defaultDestination() != currentBlock) + if (!predSwitch || op.getFlag() != predSwitch.getFlag() || + predSwitch.getDefaultDestination() != currentBlock) return failure(); // Delete case values that are not possible here. 
DenseSet caseValuesToRemove; - auto predDests = predSwitch.caseDestinations(); - auto predCaseValues = predSwitch.case_values(); + auto predDests = predSwitch.getCaseDestinations(); + auto predCaseValues = predSwitch.getCaseValues(); for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) if (currentBlock != predDests[i]) caseValuesToRemove.insert(predCaseValues->getValue(i)); @@ -1638,8 +1638,8 @@ SmallVector newCaseValues; bool requiresChange = false; - auto caseValues = op.case_values(); - auto caseDests = op.caseDestinations(); + auto caseValues = op.getCaseValues(); + auto caseDests = op.getCaseDestinations(); for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { if (caseValuesToRemove.contains(caseValues->getValue(i))) { requiresChange = true; @@ -1653,9 +1653,9 @@ if (!requiresChange) return failure(); - rewriter.replaceOpWithNewOp(op, op.flag(), op.defaultDestination(), - op.defaultOperands(), newCaseValues, - newCaseDestinations, newCaseOperands); + rewriter.replaceOpWithNewOp( + op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), + newCaseValues, newCaseDestinations, newCaseOperands); return success(); } diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -29,11 +29,12 @@ LogicalResult matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!op.condition().getType().isa()) + if (!op.getCondition().getType().isa()) return rewriter.notifyMatchFailure(op, "requires scalar condition"); - rewriter.replaceOpWithNewOp( - op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); + rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), + adaptor.getTrueValue(), + adaptor.getFalseValue()); return success(); } }; @@ -61,7 +62,7 @@ // touch the data). 
target.addDynamicallyLegalOp([&](SelectOp op) { return typeConverter.isLegal(op.getType()) || - !op.condition().getType().isa(); + !op.getCondition().getType().isa(); }); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp @@ -46,7 +46,7 @@ LogicalResult matchAndRewrite(AtomicRMWOp op, PatternRewriter &rewriter) const final { arith::CmpFPredicate predicate; - switch (op.kind()) { + switch (op.getKind()) { case AtomicRMWKind::maxf: predicate = arith::CmpFPredicate::OGT; break; @@ -58,13 +58,13 @@ } auto loc = op.getLoc(); - auto genericOp = - rewriter.create(loc, op.memref(), op.indices()); + auto genericOp = rewriter.create(loc, op.getMemref(), + op.getIndices()); OpBuilder bodyBuilder = OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener()); Value lhs = genericOp.getCurrentValue(); - Value rhs = op.value(); + Value rhs = op.getValue(); Value cmp = bodyBuilder.create(loc, predicate, lhs, rhs); Value select = bodyBuilder.create(loc, cmp, lhs, rhs); bodyBuilder.create(loc, select); @@ -126,8 +126,8 @@ LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final { - Value lhs = op.lhs(); - Value rhs = op.rhs(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); Location loc = op.getLoc(); Value cmp = rewriter.create(loc, pred, lhs, rhs); @@ -153,8 +153,8 @@ using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final { - Value lhs = op.lhs(); - Value rhs = op.rhs(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); Location loc = op.getLoc(); Value cmp = rewriter.create(loc, pred, lhs, rhs); @@ -177,8 +177,8 @@ StandardOpsDialect>(); target.addIllegalOp(); target.addDynamicallyLegalOp([](AtomicRMWOp op) { - return op.kind() != AtomicRMWKind::maxf && - op.kind() != AtomicRMWKind::minf; + return op.getKind() != AtomicRMWKind::maxf && + op.getKind() != AtomicRMWKind::minf; }); target.addDynamicallyLegalOp([](memref::ReshapeOp op) { return !op.shape().getType().cast().hasStaticShape(); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -31,7 +31,7 @@ // Substitute with the new result types from the corresponding FuncType // conversion. rewriter.replaceOpWithNewOp( - callOp, callOp.callee(), convertedResults, adaptor.getOperands()); + callOp, callOp.getCallee(), convertedResults, adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -28,7 +28,7 @@ // If we already have a global for this constant value, no need to do // anything else. 
- auto it = globals.find(constantOp.value()); + auto it = globals.find(constantOp.getValue()); if (it != globals.end()) return cast(it->second); @@ -52,14 +52,14 @@ constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), /*type=*/typeConverter.convertType(type).cast(), - /*initial_value=*/constantOp.value().cast(), + /*initial_value=*/constantOp.getValue().cast(), /*constant=*/true, /*alignment=*/memrefAlignment); symbolTable.insert(global); // The symbol table inserts at the end of the module, but globals are a bit // nicer if they are at the beginning. global->moveBefore(&moduleOp.front()); - globals[constantOp.value()] = global; + globals[constantOp.getValue()] = global; return global; } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -218,7 +218,7 @@ Optional DimOp::getConstantIndex() { if (auto constantOp = index().getDefiningOp()) - return constantOp.value().cast().getInt(); + return constantOp.getValue().cast().getInt(); return {}; } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -57,7 +57,7 @@ // Inspect constant dense values. We count up for bits that // are set, count down for bits that are cleared, and bail // when a mix is detected. - if (auto denseElts = c.value().dyn_cast()) { + if (auto denseElts = c.getValue().dyn_cast()) { int64_t val = 0; for (bool b : denseElts.getValues()) if (b && val >= 0) @@ -790,7 +790,7 @@ return vector::ContractionOp(); if (auto maybeZero = dyn_cast_or_null( contractionOp.acc().getDefiningOp())) { - if (maybeZero.value() == + if (maybeZero.getValue() == rewriter.getZeroAttr(contractionOp.acc().getType())) { BlockAndValueMapping bvm; bvm.map(contractionOp.acc(), otherOperand); @@ -2193,7 +2193,7 @@ extractStridedSliceOp.vector().getDefiningOp(); if (!constantOp) return failure(); - auto dense = constantOp.value().dyn_cast(); + auto dense = constantOp.getValue().dyn_cast(); if (!dense) return failure(); auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(), @@ -2270,7 +2270,7 @@ auto splat = op.vector().getDefiningOp(); if (!splat) return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), splat.input()); + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getInput()); return success(); } }; @@ -3666,7 +3666,7 @@ if (!constantOp) return failure(); // Only handle splat for now. - auto dense = constantOp.value().dyn_cast(); + auto dense = constantOp.getValue().dyn_cast(); if (!dense) return failure(); auto newAttr = DenseElementsAttr::get( diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -325,14 +325,15 @@ if (i < rankOffset) { // For leading dimensions, if we can prove that index are different we // know we are accessing disjoint slices. - if (indexA.value().cast().getInt() != - indexB.value().cast().getInt()) + if (indexA.getValue().cast().getInt() != + indexB.getValue().cast().getInt()) return true; } else { // For this dimension, we slice a part of the memref we need to make sure // the intervals accessed don't overlap. 
- int64_t distance = std::abs(indexA.value().cast().getInt() - - indexB.value().cast().getInt()); + int64_t distance = + std::abs(indexA.getValue().cast().getInt() - + indexB.getValue().cast().getInt()); if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) return true; } diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -655,8 +655,24 @@ SmallVector names; bool rawToo = prefixType == Dialect::EmitPrefix::Both; + // Whether to skip generating the prefixed form for an argument. This just does some + // basic checks. + // + // Slightly more invasive checks are possible for cases where not + // all ops have the trait that would cause overlap. For many cases here, + // renaming would be better (e.g., we can only guard in a limited manner against + // methods from traits and interfaces here, so avoiding these in op definition + // is safer). auto skip = [&](StringRef newName) { - bool shouldSkip = newName == "getOperands"; + bool shouldSkip = newName == "getAttributeNames" || + newName == "getAttributes" || newName == "getOperation" || + newName == "getType"; + if (newName == "getOperands") { + // To reduce noise, skip generating the prefixed form and the warning if + // $operands corresponds to a single variadic argument. + if (op.hasSingleVariadicArg()) + return true; + } if (!shouldSkip) return false; @@ -677,11 +693,11 @@ if (skip(names.back())) { rawToo = true; names.clear(); - } else { + } else if (rawToo) { LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName() - << "::" << names.back() << "\");\n" + << "::" << name << "\");\n" << "WITH_GETTER(\"" << op.getQualCppClassName() - << "Adaptor::" << names.back() << "\");\n";); + << "Adaptor::" << name << "\");\n";); } } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -222,7 +222,7 @@ static LogicalResult printOperation(CppEmitter &emitter, arith::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); - Attribute value = constantOp.value(); + Attribute value = constantOp.getValue(); return printConstantOp(emitter, operation, value); } @@ -230,7 +230,7 @@ static LogicalResult printOperation(CppEmitter &emitter, mlir::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); - Attribute value = constantOp.value(); + Attribute value = constantOp.getValue(); return printConstantOp(emitter, operation, value); } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -339,10 +339,10 @@ b.create(op.getLoc(), ArrayRef({v})); } if (GV->hasAtLeastLocalUnnamedAddr()) - op.unnamed_addrAttr(UnnamedAddrAttr::get( + op.setUnnamedAddrAttr(UnnamedAddrAttr::get( context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr()))); if (GV->hasSection()) - op.sectionAttr(b.getStringAttr(GV->getSection())); + op.setSectionAttr(b.getStringAttr(GV->getSection())); return globals[GV] = op; } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -304,8 +304,8 @@ // TODO: refactor
function type creation which usually occurs in std-LLVM // conversion. SmallVector operandTypes; - operandTypes.reserve(inlineAsmOp.operands().size()); - for (auto t : inlineAsmOp.operands().getTypes()) + operandTypes.reserve(inlineAsmOp.getOperands().size()); + for (auto t : inlineAsmOp.getOperands().getTypes()) operandTypes.push_back(t); Type resultType; @@ -330,7 +330,8 @@ inlineAsmOp.asm_string(), inlineAsmOp.constraints(), inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack()); llvm::Value *result = builder.CreateCall( - inlineAsmInst, moduleTranslation.lookupValues(inlineAsmOp.operands())); + inlineAsmInst, + moduleTranslation.lookupValues(inlineAsmOp.getOperands())); if (opInst.getNumResults() != 0) moduleTranslation.mapValue(opInst.getResult(0), result); return success(); @@ -383,7 +384,7 @@ return success(); } if (auto condbrOp = dyn_cast(opInst)) { - auto weights = condbrOp.branch_weights(); + auto weights = condbrOp.getBranchWeights(); llvm::MDNode *branchWeights = nullptr; if (weights) { // Map weight attributes to LLVM metadata. diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -366,8 +366,8 @@ // For conditional branches, we take the operands from either the "true" or // the "false" branch. return condBranchOp.getSuccessor(0) == current - ? condBranchOp.trueDestOperands()[index] - : condBranchOp.falseDestOperands()[index]; + ? condBranchOp.getTrueDestOperands()[index] + : condBranchOp.getFalseDestOperands()[index]; } if (auto switchOp = dyn_cast(terminator)) { @@ -574,8 +574,8 @@ } } - auto linkage = convertLinkageToLLVM(op.linkage()); - auto addrSpace = op.addr_space(); + auto linkage = convertLinkageToLLVM(op.getLinkage()); + auto addrSpace = op.getAddrSpace(); // LLVM IR requires constant with linkage other than external or weak // external to have initializers. 
If MLIR does not provide an initializer, @@ -587,18 +587,18 @@ cst = nullptr; auto *var = new llvm::GlobalVariable( - *llvmModule, type, op.constant(), linkage, cst, op.sym_name(), + *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(), /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal, addrSpace); - if (op.unnamed_addr().hasValue()) - var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.unnamed_addr())); + if (op.getUnnamedAddr().hasValue()) + var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr())); - if (op.section().hasValue()) - var->setSection(*op.section()); + if (op.getSection().hasValue()) + var->setSection(*op.getSection()); - addRuntimePreemptionSpecifier(op.dso_local(), var); + addRuntimePreemptionSpecifier(op.getDsoLocal(), var); - Optional alignment = op.alignment(); + Optional alignment = op.getAlignment(); if (alignment.hasValue()) var->setAlignment(llvm::MaybeAlign(alignment.getValue())); @@ -895,7 +895,7 @@ llvm::LLVMContext &ctx = llvmModule->getContext(); llvm::SmallVector operands; operands.push_back({}); // Placeholder for self-reference - if (Optional description = op.description()) + if (Optional description = op.getDescription()) operands.push_back(llvm::MDString::get(ctx, description.getValue())); llvm::MDNode *domain = llvm::MDNode::get(ctx, operands); domain->replaceOperandWith(0, domain); // Self-reference for uniqueness @@ -908,13 +908,13 @@ assert(isa(op->getParentOp())); auto metadataOp = dyn_cast(op->getParentOp()); Operation *domainOp = - SymbolTable::lookupNearestSymbolFrom(metadataOp, op.domainAttr()); + SymbolTable::lookupNearestSymbolFrom(metadataOp, op.getDomainAttr()); llvm::MDNode *domain = aliasScopeDomainMetadataMapping.lookup(domainOp); assert(domain && "Scope's domain should already be valid"); llvm::SmallVector operands; operands.push_back({}); // Placeholder for self-reference operands.push_back(domain); - if (Optional description = op.description()) + if (Optional description = op.getDescription()) operands.push_back(llvm::MDString::get(ctx, description.getValue())); llvm::MDNode *scope = llvm::MDNode::get(ctx, operands); scope->replaceOperandWith(0, scope); // Self-reference for uniqueness diff --git a/mlir/lib/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Transforms/BufferResultsToOutParams.cpp --- a/mlir/lib/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Transforms/BufferResultsToOutParams.cpp @@ -109,7 +109,7 @@ newOperands.append(outParams.begin(), outParams.end()); auto newResultTypes = llvm::to_vector<6>(llvm::map_range( replaceWithNewCallResults, [](Value v) { return v.getType(); })); - auto newCall = builder.create(op.getLoc(), op.calleeAttr(), + auto newCall = builder.create(op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands); for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -129,7 +129,7 @@ // Functions called by this function. funcOp.walk([&](CallOp callOp) { - StringAttr callee = callOp.getCalleeAttr(); + StringAttr callee = callOp.getCalleeAttr().getAttr(); for (FuncOp &funcOp : normalizableFuncs) { // We compare FuncOp and callee's name. 
if (callee == funcOp.getNameAttr()) { diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -85,7 +85,7 @@ // ----- -// expected-error @+1 {{requires attribute 'type'}} +// expected-error @+1 {{requires attribute 'global_type'}} "llvm.mlir.global"() ({}) {sym_name = "foo", constant, value = 42 : i64} : () -> () // ----- @@ -96,12 +96,12 @@ // ----- // expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}} -"llvm.mlir.global"() ({}) {sym_name = "foo", type = i64, value = 42 : i64, addr_space = -1 : i32, linkage = #llvm.linkage} : () -> () +"llvm.mlir.global"() ({}) {sym_name = "foo", global_type = i64, value = 42 : i64, addr_space = -1 : i32, linkage = #llvm.linkage} : () -> () // ----- // expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}} -"llvm.mlir.global"() ({}) {sym_name = "foo", type = i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = #llvm.linkage} : () -> () +"llvm.mlir.global"() ({}) {sym_name = "foo", global_type = i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = #llvm.linkage} : () -> () // ----- diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1,7 +1,7 @@ // RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s // CHECK: @global_aligned32 = private global i64 42, align 32 -"llvm.mlir.global"() ({}) {sym_name = "global_aligned32", type = i64, value = 42 : i64, linkage = #llvm.linkage, alignment = 32} : () -> () +"llvm.mlir.global"() ({}) {sym_name = "global_aligned32", global_type = i64, value = 42 : i64, linkage = #llvm.linkage, alignment = 32} : () -> () // CHECK: @global_aligned64 = private global i64 42, align 64 llvm.mlir.global private @global_aligned64(42 : i64) {alignment = 64 : i64} : i64