Index: mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td =================================================================== --- mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1190,41 +1190,7 @@ [NoSideEffect, ViewLikeOpInterface])>, Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyStridedMemRef:$result)>{ - let builders = [ - // Builders for a contracting reshape whose result type is computed from - // `src` and `reassociation`. - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, src, reassociationMaps, attrs); - }]>, - - // Builders for a reshape whose result type is passed explicitly. This may - // be either a contracting or expanding reshape. - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); - }]>, - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); - }]> - ]; - + code commonExtraClassDeclaration = [{ SmallVector getReassociationMaps(); SmallVector getReassociationExprs(); @@ -1285,6 +1251,25 @@ memref into memref ``` }]; + let builders = [ + // Builders using ReassociationIndices. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } @@ -1323,6 +1308,39 @@ memref into memref ``` }]; + let builders = [ + // Builders for a contracting reshape whose result type is computed from + // `src` and `reassociation`. + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, src, reassociationMaps, attrs); + }]>, + + // Builders for a reshape whose result type is passed explicitly. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } Index: mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td =================================================================== --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -677,41 +677,7 @@ Tensor_Op, Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyTensor:$result)> { - let builders = [ - // Builders for a contracting reshape whose result type is computed from - // `src` and `reassociation`. - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, src, reassociationMaps, attrs); - }]>, - - // Builders for a reshape whose result type is passed explicitly. This may - // be either a contracting or expanding reshape. - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); - }]>, - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); - }]> - ]; - + code commonExtraClassDeclaration = [{ static StringRef getReassociationAttrName() { return "reassociation"; } SmallVector getReassociationMaps(); @@ -767,6 +733,26 @@ : tensor into tensor ``` }]; + let builders = [ + // Builders using ReassociationIndices. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; + let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } @@ -796,6 +782,40 @@ : tensor into tensor ``` }]; + let builders = [ + // Builders for a contracting reshape whose result type is computed from + // `src` and `reassociation`. + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, src, reassociationMaps, attrs); + }]>, + + // Builders for a reshape whose result type is passed explicitly. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; + let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } Index: mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -2223,12 +2223,11 @@ void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns) { - patterns - .add, - FoldProducerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization>( - patterns.getContext()); + patterns.add< + FoldProducerReshapeOpByLinearization, + FoldProducerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization>( + patterns.getContext()); } void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( @@ -2236,8 +2235,7 @@ patterns .add, FoldProducerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization>( + FoldConsumerReshapeOpByLinearization>( patterns.getContext()); } Index: mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp =================================================================== --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1659,18 +1659,6 @@ AffineMapAttr::get(layout))); } -void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { - auto memRefType = src.getType().cast(); - auto resultType = computeReshapeCollapsedType( - memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( - b.getContext(), reassociation))); - build(b, result, resultType, src, attrs); - result.addAttribute(getReassociationAttrName(), - getReassociationIndicesAttribute(b, reassociation)); -} - void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { Index: mlir/lib/Dialect/Tensor/IR/TensorOps.cpp =================================================================== --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -802,18 +802,6 @@ getReassociationIndicesAttribute(b, reassociation)); } -void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { - auto resultType = computeTensorReshapeCollapsedType( - src.getType().cast(), - getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b.getContext(), reassociation))); - build(b, result, resultType, src, attrs); - result.addAttribute(getReassociationAttrName(), - getReassociationIndicesAttribute(b, reassociation)); -} - template ::value> static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,