diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -121,7 +121,7 @@ const APFloat &rhs); /// Returns the identity value attribute associated with an AtomicRMWKind op. -Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, +TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc); /// Returns the identity value associated with an AtomicRMWKind op. diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -64,7 +64,7 @@ if (!elementResult) return {}; - return DenseElementsAttr::get(resultType, *elementResult); + return DenseElementsAttr::get(cast(resultType), *elementResult); } if (operands[0].isa() && operands[1].isa()) { @@ -86,7 +86,7 @@ elementResults.push_back(*elementResult); } - return DenseElementsAttr::get(resultType, elementResults); + return DenseElementsAttr::get(cast(resultType), elementResults); } return {}; } @@ -233,7 +233,7 @@ calculate(op.getSplatValue(), castStatus); if (!castStatus) return {}; - return DenseElementsAttr::get(resType, elementResult); + return DenseElementsAttr::get(cast(resType), elementResult); } if (operands[0].isa()) { // Operand is ElementsAttr-derived; perform an element-wise fold by @@ -250,7 +250,7 @@ elementResults.push_back(elt); } - return DenseElementsAttr::get(resType, elementResults); + return DenseElementsAttr::get(cast(resType), elementResults); } return {}; } diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -129,7 +129,7 @@ /// Return the identity numeric value associated to the give op. Return /// std::nullopt if there is no known neutral element. -std::optional getNeutralElement(Operation *op); +std::optional getNeutralElement(Operation *op); //===----------------------------------------------------------------------===// // Fusion / Tiling utilities diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1618,7 +1618,7 @@ if (getAtomicReductionRegion().empty()) return {}; - return getAtomicReductionRegion().front().getArgument(0).getType(); + return cast(getAtomicReductionRegion().front().getArgument(0).getType()); } }]; let hasRegionVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -20,6 +20,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Support/LLVM.h" // Pull in all enum type definitions and utility function declarations. @@ -28,8 +29,6 @@ namespace mlir { class OpBuilder; -class TypeRange; -class ValueRange; class RewriterBase; /// Tests whether the given maps describe a row major matmul. The test is @@ -116,6 +115,11 @@ // Note: this is a true builder that notifies the OpBuilder listener. Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands); +template +OpT clone(OpBuilder &b, OpT op, TypeRange newResultTypes, + ValueRange newOperands) { + return cast(clone(b, op.getOperation(), newResultTypes, newOperands)); +} // Clone the current operation with the operands but leave the regions empty. // Note: this is a true builder that notifies the OpBuilder listener. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -116,7 +116,7 @@ // Returns a 0-valued attribute of the given `type`. This function only // supports boolean, integer, and 16-/32-/64-bit float types, and vector or // ranked tensor of them. Returns null attribute otherwise. - Attribute getZeroAttr(Type type); + TypedAttr getZeroAttr(Type type); // Convenience methods for fixed types. FloatAttr getF16FloatAttr(float value); diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -78,9 +78,9 @@ using Attribute::Attribute; /// Allow implicit conversion to ElementsAttr. - operator ElementsAttr() const { - return *this ? cast() : nullptr; - } + operator ElementsAttr() const { return cast_if_present(*this); } + /// Allow implicit conversion to TypedAttr. + operator TypedAttr() const { return ElementsAttr(*this); } /// Type trait used to check if the given type T is a potentially valid C++ /// floating point type that can be used to access the underlying element @@ -842,9 +842,10 @@ static BoolAttr get(MLIRContext *context, bool value); - /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to - /// avoid bringing in all of IntegerAttrs methods. + /// Enable conversion to IntegerAttr and its interfaces. This uses conversion + /// vs. inheritance to avoid bringing in all of IntegerAttrs methods. operator IntegerAttr() const { return IntegerAttr(impl); } + operator TypedAttr() const { return IntegerAttr(impl); } /// Return the boolean value of this attribute. bool getValue() const; diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -275,7 +275,7 @@ /// has less parameters we drop the extra attributes, if there are more /// parameters they won't have any attributes. void setType(Type newType) { - function_interface_impl::setFunctionType(this->getOperation(), newType); + function_interface_impl::setFunctionType($_op, newType); } //===------------------------------------------------------------------===// @@ -316,7 +316,7 @@ Type newType = $_op.getTypeWithArgsAndResults( argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{}); function_interface_impl::insertFunctionArguments( - this->getOperation(), argIndices, argTypes, argAttrs, argLocs, + $_op, argIndices, argTypes, argAttrs, argLocs, originalNumArgs, newType); } @@ -336,7 +336,7 @@ Type newType = $_op.getTypeWithArgsAndResults( /*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes); function_interface_impl::insertFunctionResults( - this->getOperation(), resultIndices, resultTypes, resultAttrs, + $_op, resultIndices, resultTypes, resultAttrs, originalNumResults, newType); } @@ -351,7 +351,7 @@ void eraseArguments(const BitVector &argIndices) { Type newType = $_op.getTypeWithoutArgs(argIndices); function_interface_impl::eraseFunctionArguments( - this->getOperation(), argIndices, newType); + $_op, argIndices, newType); } /// Erase a single result at `resultIndex`. @@ -365,7 +365,7 @@ void eraseResults(const BitVector &resultIndices) { Type newType = $_op.getTypeWithoutResults(resultIndices); function_interface_impl::eraseFunctionResults( - this->getOperation(), resultIndices, newType); + $_op, resultIndices, newType); } /// Return the type of this function with the specified arguments and @@ -414,7 +414,7 @@ /// Return all of the attributes for the argument at 'index'. ArrayRef getArgAttrs(unsigned index) { - return function_interface_impl::getArgAttrs(this->getOperation(), index); + return function_interface_impl::getArgAttrs($_op, index); } /// Return an ArrayAttr containing all argument attribute dictionaries of @@ -464,11 +464,11 @@ } void setAllArgAttrs(ArrayRef attributes) { assert(attributes.size() == $_op.getNumArguments()); - function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes); + function_interface_impl::setAllArgAttrDicts($_op, attributes); } void setAllArgAttrs(ArrayRef attributes) { assert(attributes.size() == $_op.getNumArguments()); - function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes); + function_interface_impl::setAllArgAttrDicts($_op, attributes); } void setAllArgAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumArguments()); @@ -503,7 +503,7 @@ /// Return all of the attributes for the result at 'index'. ArrayRef getResultAttrs(unsigned index) { - return function_interface_impl::getResultAttrs(this->getOperation(), index); + return function_interface_impl::getResultAttrs($_op, index); } /// Return an ArrayAttr containing all result attribute dictionaries of this @@ -554,12 +554,12 @@ void setAllResultAttrs(ArrayRef attributes) { assert(attributes.size() == $_op.getNumResults()); function_interface_impl::setAllResultAttrDicts( - this->getOperation(), attributes); + $_op, attributes); } void setAllResultAttrs(ArrayRef attributes) { assert(attributes.size() == $_op.getNumResults()); function_interface_impl::setAllResultAttrDicts( - this->getOperation(), attributes); + $_op, attributes); } void setAllResultAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumResults()); @@ -589,7 +589,7 @@ /// attribute is returned. DictionaryAttr getArgAttrDict(unsigned index) { assert(index < $_op.getNumArguments() && "invalid argument number"); - return function_interface_impl::getArgAttrDict(this->getOperation(), index); + return function_interface_impl::getArgAttrDict($_op, index); } /// Returns the dictionary attribute corresponding to the result at 'index'. @@ -597,7 +597,7 @@ /// returned. DictionaryAttr getResultAttrDict(unsigned index) { assert(index < $_op.getNumResults() && "invalid result number"); - return function_interface_impl::getResultAttrDict(this->getOperation(), index); + return function_interface_impl::getResultAttrDict($_op, index); } }]; diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -91,7 +91,7 @@ }; /// Construct an interface from an instance of the value type. - Interface(ValueT t = ValueT()) + explicit Interface(ValueT t = ValueT()) : BaseType(t), conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { assert((!t || conceptImpl) && diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -296,22 +296,22 @@ // on operands zero-extended to i(2*N) bits, and truncate the results back to // iN types. if (!resultType.isa()) { - Type wideType; // Shift amount necessary to extract the high bits from widened result. - Attribute shiftValAttr; + TypedAttr shiftValAttr; if (auto intTy = resultType.dyn_cast()) { unsigned resultBitwidth = intTy.getWidth(); - wideType = rewriter.getIntegerType(resultBitwidth * 2); - shiftValAttr = rewriter.getIntegerAttr(wideType, resultBitwidth); + auto attrTy = rewriter.getIntegerType(resultBitwidth * 2); + shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth); } else { auto vecTy = resultType.cast(); unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); - wideType = VectorType::get(vecTy.getShape(), - rewriter.getIntegerType(resultBitwidth * 2)); + auto attrTy = VectorType::get( + vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2)); shiftValAttr = SplatElementsAttr::get( - wideType, APInt(resultBitwidth * 2, resultBitwidth)); + attrTy, APInt(resultBitwidth * 2, resultBitwidth)); } + Type wideType = shiftValAttr.getType(); assert(LLVM::isCompatibleType(wideType) && "LLVM dialect should support all signless integer types"); diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -40,7 +40,7 @@ return element; } -Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { +TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { if (auto shapedTy = type.dyn_cast()) { Type eTy = shapedTy.getElementType(); APInt valueInt(eTy.getIntOrFloatBitWidth(), value); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -625,7 +625,7 @@ // Returns the constant initial value for a given reduction operation. The // attribute type varies depending on the element type required. -static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy, +static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { if (isa(op) && elementTy.isa()) return rewriter.getFloatAttr(elementTy, 0.0); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -31,7 +31,7 @@ using namespace mlir::tosa; static mlir::Value applyPad(Location loc, Value input, ArrayRef pad, - Attribute padAttr, OpBuilder &rewriter) { + TypedAttr padAttr, OpBuilder &rewriter) { // Input should be padded if necessary. if (llvm::all_of(pad, [](int64_t p) { return p == 0; })) return input; @@ -224,7 +224,7 @@ auto weightShape = weightTy.getShape(); // Apply padding as necessary. - Attribute zeroAttr = rewriter.getZeroAttr(inputETy); + TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy); if (isQuantized) { auto quantizationInfo = *op.getQuantizationInfo(); int64_t iZp = quantizationInfo.getInputZp(); @@ -269,7 +269,7 @@ weight = rewriter.create(loc, newWeightTy, weight, weightPermValue); - Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); + auto resultZeroAttr = rewriter.getZeroAttr(resultETy); Value emptyTensor = rewriter.create( loc, resultTy.getShape(), resultETy, filteredDims); Value zero = rewriter.create(loc, resultZeroAttr); @@ -391,7 +391,7 @@ auto resultShape = resultTy.getShape(); // Apply padding as necessary. - Attribute zeroAttr = rewriter.getZeroAttr(inputETy); + TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy); if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); @@ -439,7 +439,7 @@ indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); - Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); + auto resultZeroAttr = rewriter.getZeroAttr(resultETy); Value emptyTensor = rewriter.create( loc, linalgConvTy.getShape(), resultETy, filteredDims); Value zero = rewriter.create(loc, resultZeroAttr); @@ -604,7 +604,7 @@ loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); // When quantized, the input elemeny type is not the same as the output - Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy); + auto resultZeroAttr = rewriter.getZeroAttr(outputETy); Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter .create(loc, ValueRange{zero}, @@ -688,7 +688,7 @@ SmallVector dynamicDims = *dynamicDimsOr; // Determine what the initial value needs to be for the max pool op. - Attribute initialAttr; + TypedAttr initialAttr; if (resultETy.isF32()) initialAttr = rewriter.getFloatAttr( resultETy, @@ -768,10 +768,10 @@ pad.resize(2, 0); llvm::append_range(pad, op.getPad()); pad.resize(pad.size() + 2, 0); - Attribute padAttr = rewriter.getZeroAttr(inElementTy); + TypedAttr padAttr = rewriter.getZeroAttr(inElementTy); Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter); - Attribute initialAttr = rewriter.getZeroAttr(accETy); + auto initialAttr = rewriter.getZeroAttr(accETy); Value initialValue = rewriter.create(loc, initialAttr); ArrayRef kernel = op.getKernel(); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -296,7 +296,7 @@ padConstant = rewriter.createOrFold( loc, padOp.getPadConst(), ValueRange({})); } else { - Attribute constantAttr; + TypedAttr constantAttr; if (elementTy.isa()) { constantAttr = rewriter.getFloatAttr(elementTy, 0.0); } else if (elementTy.isa() && !padOp.getQuantizationInfo()) { diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1674,8 +1674,7 @@ dynIdx++; } else { // Create ConstantOp for static dimension. - Attribute constantAttr = - b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]); + auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]); inAffineApply.emplace_back( b.create(allocOp->getLoc(), constantAttr)); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -21,6 +21,8 @@ // Subtract two integer attributes and createa a new one with the result. def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">; +class cast : NativeCodeCall<"::mlir::cast<" # type # ">($0)">; + //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// @@ -320,8 +322,8 @@ // trunci(shrsi(x, c)) -> trunci(shrui(x, c)) def TruncIShrSIToTrunciShrUI : Pat<(Arith_TruncIOp:$tr - (Arith_ShRSIOp $x, (ConstantLikeMatcher AnyAttr:$c0))), - (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp $c0))), + (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))), + (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))), [(TruncationMatchesShiftAmount $x, $tr, $c0)]>; // trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2313,7 +2313,7 @@ //===----------------------------------------------------------------------===// /// Returns the identity value attribute associated with an AtomicRMWKind op. -Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, +TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc) { switch (kind) { case AtomicRMWKind::maxf: @@ -2362,7 +2362,7 @@ /// Returns the identity value associated with an AtomicRMWKind op. Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc) { - Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); + auto attr = getIdentityValueAttr(op, resultType, builder, loc); return builder.create(loc, attr); } diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -61,7 +61,7 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, const APInt &value) { - Attribute attr; + TypedAttr attr; if (auto intTy = type.dyn_cast()) { attr = rewriter.getIntegerAttr(type, value); } else { @@ -1003,7 +1003,7 @@ Value hiFp = rewriter.create(loc, resultTy, hiInt); int64_t pow2Int = int64_t(1) << newBitWidth; - Attribute pow2Attr = + TypedAttr pow2Attr = rewriter.getFloatAttr(resultElemTy, static_cast(pow2Int)); if (auto vecTy = dyn_cast(resultTy)) pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr); diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -121,7 +121,7 @@ if (getValueAttr().isa()) return success(); - TypedAttr value = getValueAttr(); + auto value = cast(getValueAttr()); Type type = getType(); if (!value.getType().isa() && type != value.getType()) return emitOpError() << "requires attribute's type (" << value.getType() @@ -177,7 +177,7 @@ if (getValueAttr().isa()) return success(); - TypedAttr value = getValueAttr(); + auto value = cast(getValueAttr()); Type type = getType(); if (!value.getType().isa() && type != value.getType()) return emitOpError() << "requires attribute's type (" << value.getType() diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -490,9 +490,9 @@ llvm::dbgs() << "\n"); // Step 2. sort the values by the corresponding DeviceMappingAttrInterface. - auto comparator = [&](DeviceMappingAttrInterface a, - DeviceMappingAttrInterface b) -> bool { - return a.getMappingId() < b.getMappingId(); + auto comparator = [&](Attribute a, Attribute b) -> bool { + return cast(a).getMappingId() < + cast(b).getMappingId(); }; SmallVector forallMappingSizes = getValuesSortedByKey(forallMappingAttrs, tmpMappingSizes, comparator); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -389,10 +389,7 @@ OpBuilder builder = getBuilder(); Location loc = builder.getUnknownLoc(); Attribute valueAttr = parseAttribute(value, builder.getContext()); - Type type = NoneType::get(builder.getContext()); - if (auto typedAttr = valueAttr.dyn_cast()) - type = typedAttr.getType(); - return builder.create(loc, type, valueAttr); + return builder.create(loc, ::cast(valueAttr)); } Value index(int64_t dim) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -165,7 +165,7 @@ staticStridesVector)); } - Operation *clonedOp = clone(b, producer, resultTypes, clonedShapes); + LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes); // Shift all IndexOp results by the tile offset. SmallVector allIvs = llvm::to_vector( diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -38,7 +38,7 @@ static SmallVector getTiledSliceDims(OpOperand *consumerOperand, ArrayRef tiledLoopDims) { // Get the consumer operand indexing map. - LinalgOp consumerOp = consumerOperand->getOwner(); + auto consumerOp = cast(consumerOperand->getOwner()); AffineMap indexingMap = consumerOp.getMatchingIndexingMap(consumerOperand); // Search the slice dimensions tiled by a tile loop dimension. @@ -65,7 +65,7 @@ static SmallVector getTiledProducerLoops(OpResult producerResult, ArrayRef tiledSliceDimIndices) { - LinalgOp producerOp = producerResult.getOwner(); + auto producerOp = cast(producerResult.getOwner()); // Get the indexing map of the `producerOp` output operand that matches // ´producerResult´. @@ -137,7 +137,7 @@ b.setInsertionPointAfter(sliceOp); // Get the producer. - LinalgOp producerOp = producerResult.getOwner(); + auto producerOp = cast(producerResult.getOwner()); Location loc = producerOp.getLoc(); // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. @@ -345,7 +345,7 @@ return failure(); // Check `sliceOp` and `consumerOp` are in the same block. - LinalgOp consumerOp = consumerOpOperand->getOwner(); + auto consumerOp = cast(consumerOpOperand->getOwner()); if (sliceOp->getBlock() != rootOp->getBlock() || consumerOp->getBlock() != rootOp->getBlock()) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -198,7 +198,7 @@ "expected the number of loops and induction variables to match"); // Replace the index operations in the body of the innermost loop op. if (!loopOps.empty()) { - LoopLikeOpInterface loopOp = loopOps.back(); + auto loopOp = cast(loopOps.back()); for (IndexOp indexOp : llvm::make_early_inc_range(loopOp.getLoopBody().getOps())) rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -66,7 +66,7 @@ return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); Operation *reductionOp = combinerOps[0]; - std::optional identity = getNeutralElement(reductionOp); + std::optional identity = getNeutralElement(reductionOp); if (!identity.has_value()) return b.notifyMatchFailure(op, "Unknown identity value for the reduction"); @@ -272,9 +272,9 @@ if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps)) return b.notifyMatchFailure(op, "cannot match a reduction pattern"); - SmallVector neutralElements; + SmallVector neutralElements; for (Operation *reductionOp : combinerOps) { - std::optional neutralElement = getNeutralElement(reductionOp); + std::optional neutralElement = getNeutralElement(reductionOp); if (!neutralElement.has_value()) return b.notifyMatchFailure(op, "cannot find neutral element."); neutralElements.push_back(*neutralElement); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -271,7 +271,7 @@ return op->emitOpError("Failed to anaysis the reduction operation."); Operation *reductionOp = combinerOps[0]; - std::optional identity = getNeutralElement(reductionOp); + std::optional identity = getNeutralElement(reductionOp); if (!identity.has_value()) return op->emitOpError( "Failed to get an identity value for the reduction operation."); @@ -328,8 +328,8 @@ // Step 1: Extract a slice of the input operands. SmallVector valuesToTile = linalgOp.getDpsInputOperands(); - SmallVector tiledOperands = - makeTiledShapes(b, loc, op, valuesToTile, offsets, sizes, {}, true); + SmallVector tiledOperands = makeTiledShapes( + b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); // Step 2: Extract the accumulator operands SmallVector strides(offsets.size(), b.getIndexAttr(1)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -83,11 +83,8 @@ return rewriter.notifyMatchFailure(opToPad, "--no padding value specified"); } Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()]; - Type paddingType = rewriter.getType(); - if (auto typedAttr = paddingAttr.dyn_cast()) - paddingType = typedAttr.getType(); Value paddingValue = rewriter.create( - opToPad.getLoc(), paddingType, paddingAttr); + opToPad.getLoc(), cast(paddingAttr)); // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp. OpOperand *currOpOperand = opOperand; @@ -576,7 +573,7 @@ rewriter, loc, operand, innerPackSizes, innerPos, /*outerDimsPerm=*/{}); // TODO: value of the padding attribute should be determined by consumers. - Attribute zeroAttr = + auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); Value zero = rewriter.create(loc, zeroAttr); packOps.push_back(rewriter.create( diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -983,7 +983,7 @@ } /// Return the identity numeric value associated to the give op. -std::optional getNeutralElement(Operation *op) { +std::optional getNeutralElement(Operation *op) { // Builder only used as helper for attribute creation. OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -1245,7 +1245,7 @@ floatTy = broadcast(floatTy, shape); intTy = broadcast(intTy, shape); - auto bconst = [&](Attribute attr) -> Value { + auto bconst = [&](TypedAttr attr) -> Value { Value value = b.create(attr); return broadcast(b, value, shape); }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -111,9 +111,9 @@ loc, rewriter.getIndexType(), size); sizes[i] = size; } else { - sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i)); - size = - rewriter.create(loc, sizes[i].get()); + auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i)); + size = rewriter.create(loc, sizeAttr); + sizes[i] = sizeAttr; } strides[i] = stride; if (i > 0) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -44,7 +44,7 @@ if (auto intTy = type.dyn_cast()) return IntegerAttr::get(intTy, sizedValue); - return SplatElementsAttr::get(type, sizedValue); + return SplatElementsAttr::get(cast(type), sizedValue); } Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -79,7 +79,7 @@ /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, /// for unsupported types we raise `llvm_unreachable` rather than /// returning a null attribute. -Attribute getOneAttr(Builder &builder, Type tp); +TypedAttr getOneAttr(Builder &builder, Type tp); /// Generates the comparison `v != 0` where `v` is of numeric type. /// For floating types, we use the "unordered" comparator (i.e., returns diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -213,7 +213,7 @@ return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast); } -mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { +mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { if (tp.isa()) return builder.getFloatAttr(tp, 1.0); if (tp.isa()) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1729,7 +1729,7 @@ auto splat = vectorCst.dyn_cast(); if (!splat) return failure(); - Attribute newAttr = splat.getSplatValue(); + TypedAttr newAttr = splat.getSplatValue(); if (auto vecDstType = extractOp.getType().dyn_cast()) newAttr = DenseElementsAttr::get(vecDstType, newAttr); rewriter.replaceOpWithNewOp(extractOp, newAttr); @@ -1767,9 +1767,9 @@ copy(getI64SubArray(extractOp.getPosition()), completePositions.begin()); int64_t elemBeginPosition = linearize(completePositions, computeStrides(vecTy.getShape())); - auto denseValuesBegin = dense.value_begin() + elemBeginPosition; + auto denseValuesBegin = dense.value_begin() + elemBeginPosition; - Attribute newAttr; + TypedAttr newAttr; if (auto resVecTy = extractOp.getType().dyn_cast()) { SmallVector elementValues( denseValuesBegin, denseValuesBegin + resVecTy.getNumElements()); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -191,7 +191,7 @@ private: LogicalResult matchAndRewrite(MaskOp maskOp, PatternRewriter &rewriter) const final { - MaskableOpInterface maskableOp = maskOp.getMaskableOp(); + auto maskableOp = cast(maskOp.getMaskableOp()); SourceOp sourceOp = dyn_cast(maskableOp.getOperation()); if (!sourceOp) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -692,8 +692,8 @@ return failure(); unsigned operandIndex = yieldOperand->getOperandNumber(); Attribute scalarAttr = dense.getSplatValue(); - Attribute newAttr = DenseElementsAttr::get( - warpOp.getResult(operandIndex).getType(), scalarAttr); + auto newAttr = DenseElementsAttr::get( + cast(warpOp.getResult(operandIndex).getType()), scalarAttr); Location loc = warpOp.getLoc(); rewriter.setInsertionPointAfter(warpOp); Value distConstant = rewriter.create(loc, newAttr); diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -79,7 +79,7 @@ src = rewriter.create(op.getLoc(), opType, op.getConstantSrcAttr()); } else { - Attribute zeroAttr = rewriter.getZeroAttr(opType); + auto zeroAttr = rewriter.getZeroAttr(opType); src = rewriter.create(op->getLoc(), opType, zeroAttr); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -315,7 +315,7 @@ return getArrayAttr(attrs); } -Attribute Builder::getZeroAttr(Type type) { +TypedAttr Builder::getZeroAttr(Type type) { if (type.isa()) return getFloatAttr(type, 0.0); if (type.isa()) diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -539,7 +539,7 @@ elementType.getContext()); // Wrap AffineMap into Attribute. - Attribute layout = AffineMapAttr::get(map); + auto layout = AffineMapAttr::get(map); // Drop default memory space value and replace it with empty attribute. memorySpace = skipDefaultMemorySpace(memorySpace); @@ -559,7 +559,7 @@ elementType.getContext()); // Wrap AffineMap into Attribute. - Attribute layout = AffineMapAttr::get(map); + auto layout = AffineMapAttr::get(map); // Drop default memory space value and replace it with empty attribute. memorySpace = skipDefaultMemorySpace(memorySpace); @@ -577,7 +577,7 @@ elementType.getContext()); // Wrap AffineMap into Attribute. - Attribute layout = AffineMapAttr::get(map); + auto layout = AffineMapAttr::get(map); // Convert deprecated integer-like memory space to Attribute. Attribute memorySpace = @@ -598,7 +598,7 @@ elementType.getContext()); // Wrap AffineMap into Attribute. - Attribute layout = AffineMapAttr::get(map); + auto layout = AffineMapAttr::get(map); // Convert deprecated integer-like memory space to Attribute. Attribute memorySpace = diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -220,7 +220,7 @@ /// Creates a spirv::SpecConstantOp. spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, - Attribute defaultValue); + TypedAttr defaultValue); /// Processes the OpVariable instructions at current `offset` into `binary`. /// It is expected that this method is used for variables that are to be diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -535,7 +535,7 @@ spirv::SpecConstantOp spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, - Attribute defaultValue) { + TypedAttr defaultValue) { auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); auto op = opBuilder.create(unknownLoc, symName, defaultValue); diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -222,7 +222,7 @@ Block::iterator inlinePoint, IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, std::optional inlineLoc, - bool shouldCloneInlinedRegion, Operation *call = nullptr) { + bool shouldCloneInlinedRegion, CallOpInterface call = {}) { assert(resultsToReplace.size() == regionResultTypes.size()); // We expect the region to have at least one block. if (src->empty()) @@ -328,7 +328,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands, ValueRange resultsToReplace, std::optional inlineLoc, - bool shouldCloneInlinedRegion, Operation *call = nullptr) { + bool shouldCloneInlinedRegion, CallOpInterface call = {}) { // We expect the region to have at least one block. if (src->empty()) return failure(); diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -50,7 +50,7 @@ auto *originalOp = info->originalProducer.getOperation(); auto *originalOpInLinalgOpsVector = std::find(linalgOps.begin(), linalgOps.end(), originalOp); - *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); + *originalOpInLinalgOpsVector = info->fusedProducer; // Don't mark for erasure in the tensor case, let DCE handle this. changed = true; }