diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -147,11 +147,11 @@
                ConversionPatternRewriter &rewriter) const final {
     if constexpr (SourceOp::hasProperties())
-      rewrite(cast<SourceOp>(op),
-              OpAdaptor(operands, op->getAttrDictionary(),
-                        cast<SourceOp>(op).getProperties()),
-              rewriter);
-    rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
-            rewriter);
+      return rewrite(cast<SourceOp>(op),
+                     OpAdaptor(operands, op->getDiscardableAttrDictionary(),
+                               cast<SourceOp>(op).getProperties()),
+                     rewriter);
+    rewrite(cast<SourceOp>(op),
+            OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
   }
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
@@ -161,12 +161,13 @@
                   ConversionPatternRewriter &rewriter) const final {
     if constexpr (SourceOp::hasProperties())
       return matchAndRewrite(cast<SourceOp>(op),
-                             OpAdaptor(operands, op->getAttrDictionary(),
+                             OpAdaptor(operands,
+                                       op->getDiscardableAttrDictionary(),
                                        cast<SourceOp>(op).getProperties()),
                              rewriter);
-    return matchAndRewrite(cast<SourceOp>(op),
-                           OpAdaptor(operands, op->getAttrDictionary()),
-                           rewriter);
+    return matchAndRewrite(
+        cast<SourceOp>(op),
+        OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1890,11 +1890,12 @@
     if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
       if constexpr (hasProperties()) {
         result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
-            operands, op->getAttrDictionary(),
+            operands, op->getDiscardableAttrDictionary(),
             cast<ConcreteOpT>(op).getProperties(), op->getRegions()));
       } else {
         result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
-            operands, op->getAttrDictionary(), {}, op->getRegions()));
+            operands, op->getDiscardableAttrDictionary(), {},
+            op->getRegions()));
       }
     } else {
       result = cast<ConcreteOpT>(op).fold(operands);
@@ -1920,13 +1921,14 @@
       if constexpr (hasProperties()) {
         result = cast<ConcreteOpT>(op).fold(
             typename ConcreteOpT::FoldAdaptor(
-                operands, op->getAttrDictionary(),
+                operands, op->getDiscardableAttrDictionary(),
                 cast<ConcreteOpT>(op).getProperties(), op->getRegions()),
             results);
       } else {
         result = cast<ConcreteOpT>(op).fold(
-            typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
-                                              {}, op->getRegions()),
+            typename ConcreteOpT::FoldAdaptor(
+                operands, op->getDiscardableAttrDictionary(), {},
+                op->getRegions()),
             results);
       }
     } else {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -520,7 +520,12 @@
   }
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
+    auto sourceOp = cast<SourceOp>(op);
+    if constexpr (SourceOp::hasProperties())
+      return rewrite(sourceOp,
+                     OpAdaptor(operands, op->getDiscardableAttrDictionary(),
+                               sourceOp.getProperties()),
+                     rewriter);
+    rewrite(sourceOp, OpAdaptor(operands, op->getDiscardableAttrDictionary()),
             rewriter);
   }
   LogicalResult
@@ -529,11 +534,13 @@
     auto sourceOp = cast<SourceOp>(op);
     if constexpr (SourceOp::hasProperties())
       return matchAndRewrite(sourceOp,
-                             OpAdaptor(operands, op->getAttrDictionary(),
+                             OpAdaptor(operands,
+                                       op->getDiscardableAttrDictionary(),
                                        sourceOp.getProperties()),
                              rewriter);
     return matchAndRewrite(
-        sourceOp, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
+        sourceOp, OpAdaptor(operands, op->getDiscardableAttrDictionary()),
+        rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -131,8 +131,11 @@
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    return matchAndRewrite(cast<memref::ReallocOp>(op),
-                           OpAdaptor(operands, op->getAttrDictionary()),
+    auto reallocOp = cast<memref::ReallocOp>(op);
+    return matchAndRewrite(reallocOp,
+                           OpAdaptor(operands,
+                                     op->getDiscardableAttrDictionary(),
+                                     reallocOp.getProperties()),
                            rewriter);
   }
 
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -111,10 +111,10 @@
     resultTypes.reserve(1 + op->getNumResults());
     copy(op->getResultTypes(), std::back_inserter(resultTypes));
     resultTypes.push_back(tokenType);
-    auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
-                                    op->getOperands(), op->getAttrDictionary(),
-                                    op->getPropertiesStorage(),
-                                    op->getSuccessors(), op->getNumRegions());
+    auto *newOp = Operation::create(
+        op->getLoc(), op->getName(), resultTypes, op->getOperands(),
+        op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
+        op->getSuccessors(), op->getNumRegions());
 
     // Clone regions into new op.
     IRMapping mapping;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -393,7 +393,9 @@
   }
 
   bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
-    return lhs->getAttrDictionary() == rhs->getAttrDictionary();
+    return lhs->getDiscardableAttrDictionary() ==
+               rhs->getDiscardableAttrDictionary() &&
+           lhs->getPropertiesAsAttribute() == rhs->getPropertiesAsAttribute();
   }
 
   // Returns a source value for the given block.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -368,7 +368,8 @@
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape = operands.getShape(0);
-  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
+  auto *prop = properties.as<Properties *>();
+  IntegerAttr axis = prop->axis;
   int32_t axisVal = axis.getValue().getSExtValue();
 
   if (!inputShape.hasRank()) {
@@ -431,8 +432,8 @@
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // Infer all dimension sizes by reducing based on inputs.
-  int32_t axis =
-      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
+  auto *prop = properties.as<Properties *>();
+  int32_t axis = prop->axis.getValue().getSExtValue();
   llvm::SmallVector<int64_t> outputShape;
   bool hasRankedInput = false;
   for (auto operand : operands) {
@@ -969,7 +970,7 @@
     Type inputType =                                                           \
         operands.getType()[0].cast<TensorType>().getElementType();             \
     return ReduceInferReturnTypes(operands.getShape(0), inputType,             \
-                                  attributes.get("axis").cast<IntegerAttr>(),  \
+                                  properties.as<Properties *>()->axis,         \
                                   inferredReturnShapes);                       \
   }                                                                            \
   COMPATIBLE_RETURN_TYPES(OP)
@@ -1046,6 +1047,7 @@
 
 static LogicalResult poolingInferReturnTypes(
     const ValueShapeRange &operands, DictionaryAttr attributes,
+    ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, ArrayRef<int64_t> pad,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape = operands.getShape(0);
   llvm::SmallVector<int64_t> outputShape;
@@ -1064,10 +1066,6 @@
   int64_t height = inputShape.getDimSize(1);
   int64_t width = inputShape.getDimSize(2);
 
-  ArrayRef<int64_t> kernel = attributes.get("kernel").cast<DenseI64ArrayAttr>();
-  ArrayRef<int64_t> stride = attributes.get("stride").cast<DenseI64ArrayAttr>();
-  ArrayRef<int64_t> pad = attributes.get("pad").cast<DenseI64ArrayAttr>();
-
   if (!ShapedType::isDynamic(height)) {
     int64_t padded = height + pad[0] + pad[1] - kernel[0];
     outputShape[1] = padded / stride[0] + 1;
@@ -1227,7 +1225,9 @@
     ValueShapeRange operands, DictionaryAttr attributes,
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
+  Properties &prop = *properties.as<Properties *>();
+  return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
+                                 prop.pad, inferredReturnShapes);
 }
 
 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
@@ -1235,7 +1235,9 @@
     ValueShapeRange operands, DictionaryAttr attributes,
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
+  Properties &prop = *properties.as<Properties *>();
+  return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
+                                 prop.pad, inferredReturnShapes);
 }
 
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -37,10 +37,10 @@
 
   SmallVector<ShapedTypeComponents> returnedShapes;
   if (shapeInterface
-          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
-                                     op->getOperands(), op->getAttrDictionary(),
-                                     op->getPropertiesStorage(),
-                                     op->getRegions(), returnedShapes)
+          .inferReturnTypeComponents(
+              op.getContext(), op.getLoc(), op->getOperands(),
+              op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
+              op->getRegions(), returnedShapes)
           .failed())
     return op;
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -218,9 +218,10 @@
 
       ValueShapeRange range(op.getOperands(), operandShape);
       if (shapeInterface
-              .inferReturnTypeComponents(
-                  op.getContext(), op.getLoc(), range, op.getAttrDictionary(),
-                  op.getPropertiesStorage(), op.getRegions(), returnedShapes)
+              .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
+                                         op.getDiscardableAttrDictionary(),
+                                         op.getPropertiesStorage(),
+                                         op.getRegions(), returnedShapes)
               .succeeded()) {
         for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
           Value result = std::get<0>(it);
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -651,7 +651,7 @@
   //   - Attributes
   //   - Result Types
   llvm::hash_code hash =
-      llvm::hash_combine(op->getName(), op->getAttrDictionary(),
+      llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
                          op->getResultTypes(), op->hashProperties());
 
   //   - Operands
@@ -766,11 +766,14 @@
 
   // 1. Compare the operation properties.
   if (lhs->getName() != rhs->getName() ||
-      lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
+      lhs->getDiscardableAttrDictionary() !=
+          rhs->getDiscardableAttrDictionary() ||
       lhs->getNumRegions() != rhs->getNumRegions() ||
       lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
       lhs->getNumOperands() != rhs->getNumOperands() ||
-      lhs->getNumResults() != rhs->getNumResults())
+      lhs->getNumResults() != rhs->getNumResults() ||
+      lhs->hashProperties() != rhs->hashProperties() ||
+      lhs->getPropertiesAsAttribute() != rhs->getPropertiesAsAttribute())
     return false;
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
@@ -874,7 +877,9 @@
     //   - Operation pointer
     addDataToHash(hasher, op);
     //   - Attributes
-    addDataToHash(hasher, op->getAttrDictionary());
+    addDataToHash(hasher, op->getDiscardableAttrDictionary());
+    //   - Properties
+    addDataToHash(hasher, op->hashProperties());
     //   - Blocks in Regions
     for (Region &region : op->getRegions()) {
       for (Block &block : region) {
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -174,7 +174,7 @@
       return op.emitError("null operand found");
 
   /// Verify that all of the attributes are okay.
-  for (auto attr : op.getAttrs()) {
+  for (auto attr : op.getDiscardableAttrDictionary()) {
     // Check for any optional dialect specific attributes.
     if (auto *dialect = attr.getNameDialect())
       if (failed(dialect->verifyOperationAttribute(&op, attr)))
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -251,8 +251,8 @@
   auto retTypeFn = cast<InferTypeOpInterface>(op);
   auto result = retTypeFn.refineReturnTypes(
       op->getContext(), op->getLoc(), op->getOperands(),
-      op->getAttrDictionary(), op->getPropertiesStorage(), op->getRegions(),
-      inferredReturnTypes);
+      op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
+      op->getRegions(), inferredReturnTypes);
 
   if (failed(result))
     op->emitOpError() << "failed to infer returned types";
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -433,7 +433,7 @@
       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
       SmallVector<Type, 2> inferredReturnTypes;
       if (succeeded(OpTy::inferReturnTypes(
-              context, std::nullopt, values, op->getAttrDictionary(),
+              context, std::nullopt, values, op->getDiscardableAttrDictionary(),
               op->getPropertiesStorage(), op->getRegions(),
               inferredReturnTypes))) {
         OperationState state(location, OpTy::getOperationName());
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
--- a/mlir/test/lib/IR/TestOperationEquals.cpp
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -31,7 +31,7 @@
     Operation *first = &module.getBody()->front();
     llvm::outs() << first->getName().getStringRef() << " with attr "
-                 << first->getAttrDictionary();
+                 << first->getDiscardableAttrDictionary();
     OperationEquivalence::Flags flags{};
     if (!first->hasAttr("strict_loc_check"))
       flags |= OperationEquivalence::IgnoreLocations;