diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1193,41 +1193,7 @@ [NoSideEffect, ViewLikeOpInterface])>, Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyStridedMemRef:$result)>{ - let builders = [ - // Builders for a contracting reshape whose result type is computed from - // `src` and `reassociation`. - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, src, reassociationMaps, attrs); - }]>, - - // Builders for a reshape whose result type is passed explicitly. This may - // be either a contracting or expanding reshape. - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); - }]>, - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); - }]> - ]; - + code commonExtraClassDeclaration = [{ SmallVector getReassociationMaps(); SmallVector getReassociationExprs(); @@ -1288,6 +1254,25 @@ memref into memref ``` }]; + let builders = [ + // Builders for a reshape whose result type is passed explicitly.
+ OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } @@ -1326,6 +1311,39 @@ memref into memref ``` }]; + let builders = [ + // Builders for a contracting reshape whose result type is computed from + // `src` and `reassociation`. + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, src, reassociationMaps, attrs); + }]>, + + // Builders for a reshape whose result type is passed explicitly. 
+ OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -678,41 +678,7 @@ Tensor_Op, Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyTensor:$result)> { - let builders = [ - // Builders for a contracting reshape whose result type is computed from - // `src` and `reassociation`. - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, src, reassociationMaps, attrs); - }]>, - - // Builders for a reshape whose result type is passed explicitly. This may - // be either a contracting or expanding reshape. 
- OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); - }]>, - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); - }]> - ]; - + code commonExtraClassDeclaration = [{ static StringRef getReassociationAttrName() { return "reassociation"; } SmallVector getReassociationMaps(); @@ -768,6 +734,26 @@ : tensor into tensor ``` }]; + let builders = [ + // Builders for a reshape whose result type is passed explicitly. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; + let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } @@ -797,6 +783,40 @@ : tensor into tensor ``` }]; + let builders = [ + // Builders for a contracting reshape whose result type is computed from + // `src` and `reassociation`.
+ OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, src, reassociationMaps, attrs); + }]>, + + // Builders for a reshape whose result type is passed explicitly. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; + let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -2223,12 +2223,11 @@ void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns) { - patterns - .add, - FoldProducerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization>( - patterns.getContext()); + patterns.add< + FoldProducerReshapeOpByLinearization, + FoldProducerReshapeOpByLinearization, + FoldConsumerReshapeOpByLinearization>( + patterns.getContext()); } void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( @@ -2236,8 +2235,7 @@ patterns .add, FoldProducerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization, - 
FoldConsumerReshapeOpByLinearization>( + FoldConsumerReshapeOpByLinearization>( patterns.getContext()); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1669,18 +1669,6 @@ AffineMapAttr::get(layout))); } -void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { - auto memRefType = src.getType().cast(); - auto resultType = computeReshapeCollapsedType( - memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( - b.getContext(), reassociation))); - build(b, result, resultType, src, attrs); - result.addAttribute(getReassociationAttrName(), - getReassociationIndicesAttribute(b, reassociation)); -} - void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -817,18 +817,6 @@ getReassociationIndicesAttribute(b, reassociation)); } -void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { - auto resultType = computeTensorReshapeCollapsedType( - src.getType().cast(), - getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b.getContext(), reassociation))); - build(b, result, resultType, src, attrs); - result.addAttribute(getReassociationAttrName(), - getReassociationIndicesAttribute(b, reassociation)); -} - template ::value> static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,