diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -18,24 +18,24 @@ include "mlir/Dialect/AffineOps/AffineOps.td" def HasNoLinalgTransformMarker : CPred<[{ - !$0.getAttrOfType(LinalgTransforms::kLinalgTransformMarker) + !op.getAttrOfType(LinalgTransforms::kLinalgTransformMarker) }]>; class HasLinalgTransformMarker : CPred<[{ - $0.getAttrOfType( + op.getAttrOfType( LinalgTransforms::kLinalgTransformMarker) && - $0.getAttrOfType( + op.getAttrOfType( LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>; class IsProducedByOpOfType : - CPred<"isProducedByOpOfType<" # str # ">($0, $1)">; + CPred<"isProducedByOpOfType<" # str # ">(op, $0)">; class AffineMapDomainHasDim : CPred<[{ - $0.getAttrOfType(getIndexingMapsAttrName()).getValue()[0]. + op.getAttrOfType(getIndexingMapsAttrName()).getValue()[0]. cast().getValue().getNumDims() ==}] # n # [{}]>; class HasOperandsOfType: CPred<[{ - llvm::any_of($0.getOperands(), + llvm::any_of(op.getOperands(), [](Value v) { return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp()); }) @@ -50,7 +50,7 @@ // patterns. class TileAndFuseLinalgOp< list sizes, list operandIndices, string value> : NativeCodeCall< - "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" # + "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" # StrJoinInt.result # "}, {" # StrJoinInt.result # "}," # " \"" # value # "\")))" # " return matchFailure();">; @@ -67,7 +67,7 @@ // of elements as `sizes`. class TileLinalgOp sizes, string value, list permutation=[]> : NativeCodeCall< - "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" # + "if (failed(tileLinalgOpAndSetMarker($_builder, op, {" # StrJoinInt.result # "}, \"" # value # "\", {" # StrJoinInt.result # "})))" # " return matchFailure();">; @@ -76,18 +76,18 @@ // Linalg to loop patterns. //===----------------------------------------------------------------------===// class LinalgOpToLoops : NativeCodeCall< - "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " # + "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " # " return matchFailure();">; class LinalgOpToAffineLoops : NativeCodeCall< - "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " # + "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " # " return matchFailure();">; //===----------------------------------------------------------------------===// // Linalg to vector contraction patterns. //===----------------------------------------------------------------------===// -class LinalgOpToVectorContraction : NativeCodeCall< - "if (failed(vectorizeGenericOp($_builder, $0))) " # +class VectorizeGenericLinalgOp : NativeCodeCall< + "if (failed(vectorizeGenericLinalgOp($_builder, op))) " # " return matchFailure();">; //===----------------------------------------------------------------------===// @@ -95,14 +95,14 @@ //===----------------------------------------------------------------------===// class PermuteGenericLinalgOp permutation, string value> : NativeCodeCall< - "if (failed(permuteGenericLinalgOp($_builder, $0, {" # + "if (failed(permuteGenericLinalgOp($_builder, op, {" # StrJoinInt.result # "}, \"" # value # "\"))) " # " return matchFailure();">; //===----------------------------------------------------------------------===// // Linalg promote subview operands. //===----------------------------------------------------------------------===// -class LinalgOpPromoteSubviews : NativeCodeCall< - "if (failed(linalgOpPromoteSubviews($_builder, $0))) " # +class PromoteSubviewsLinalgOp : NativeCodeCall< + "if (failed(promoteSubviewsLinalgOp($_builder, op))) " # " return matchFailure();">; #endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -79,7 +79,8 @@ LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); /// Rewrite a linalg.generic into a suitable vector.contraction op. -LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op); +LogicalResult vectorizeGenericLinalgOp(PatternRewriter &rewriter, + Operation *op); /// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` /// and `iterator_types` permutated according to `permutation`. @@ -88,7 +89,7 @@ StringRef linalgMarker); /// Promote std.subviews feeding linalg operations -LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op); +LogicalResult promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -153,8 +153,8 @@ genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp); } -LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter, - Operation *op) { +LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, + Operation *op) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Rewrite linalg op as vector.contract: " << *op << ":\n"); @@ -223,7 +223,7 @@ return success(); } -LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, +LogicalResult mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op) { LinalgOp linOp = dyn_cast(op); SetVector subViews; diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -19,11 +19,11 @@ //===----------------------------------------------------------------------===// // Test Linalg fusion patterns. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$consumer $A, $B, $C), - (TileAndFuseLinalgOp<[100, 150], [0], "L1"> $consumer), +def : Pat<(MatmulOp:$op $A, $_, $_), + (TileAndFuseLinalgOp<[100, 150], [0], "L1">), [ - (Constraint $consumer), - (Constraint> $consumer, $A), + (Constraint), + (Constraint> $A), ], // In the buffer world there is no use-def chains or dags so benefits // cannot be computed automatically from the length of the matched @@ -36,91 +36,91 @@ //===----------------------------------------------------------------------===// // Linalg tiling patterns. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2000, 3000, 4000], "L3"> $op), +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2000, 3000, 4000], "L3">), [(Constraint]>> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[200, 300, 400], "L2"> $op), - [(Constraint> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[20, 30, 40], "L1"> $op), - [(Constraint> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2, 3, 4], "REG"> $op), - [(Constraint> $op)]>; + HasLinalgTransformMarker<"MEM">]>>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[200, 300, 400], "L2">), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[20, 30, 40], "L1">), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2, 3, 4], "REG">), + [(Constraint>)]>; -def : Pattern<(MatvecOp:$op $A, $b, $c), - [(TileLinalgOp<[5, 6], "L1"> $op)], - [(Constraint $op)]>; +def : Pattern<(MatvecOp:$op $_, $_, $_), + [(TileLinalgOp<[5, 6], "L1">)], + [(Constraint)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8000], "L1"> $op)], +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8000], "L1">)], [(Constraint, HasLinalgTransformMarker<"L3">, - HasLinalgTransformMarker<"L2">]>> $op)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8], "REG"> $op)], - [(Constraint> $op)]>; + HasLinalgTransformMarker<"L2">]>>)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8], "REG">)], + [(Constraint>)]>; //===----------------------------------------------------------------------===// // Linalg tiling and permutation patterns. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]> $op), - [(Constraint> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]> $op), - [(Constraint> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[20, 30, 40], "REG__with_perm__"> $op), - [(Constraint> $op)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[20, 30, 40], "REG__with_perm__">), + [(Constraint>)]>; -def : Pattern<(MatvecOp:$op $A, $b, $c), - [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]> $op)], - [(Constraint> $op)]>; +def : Pattern<(MatvecOp:$op $_, $_, $_), + [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)], + [(Constraint>)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8000], "L1__with_perm__"> $op)], - [(Constraint> $op)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8], "REG__with_perm__"> $op)], - [(Constraint> $op)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8000], "L1__with_perm__">)], + [(Constraint>)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8], "REG__with_perm__">)], + [(Constraint>)]>; //===----------------------------------------------------------------------===// // Linalg to loops patterns. //===----------------------------------------------------------------------===// -def : Pattern<(DotOp:$op $a, $b, $c), - [(LinalgOpToLoops<"DotOp"> $op)], - [(Constraint> $op)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(LinalgOpToLoops<"DotOp">)], + [(Constraint>)]>; //===----------------------------------------------------------------------===// // Linalg to vector contraction patterns. //===----------------------------------------------------------------------===// -def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - [(LinalgOpToVectorContraction<"GenericOp"> $op)], - [(Constraint> $op)]>; +def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + [(VectorizeGenericLinalgOp<"GenericOp">)], + [(Constraint>)]>; //===----------------------------------------------------------------------===// // Linalg generic permutation patterns. //===----------------------------------------------------------------------===// -def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), +def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">), [(Constraint]>> $op)]>; + AffineMapDomainHasDim<3>]>>)]>; -def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), +def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">), [(Constraint]>> $op)]>; + AffineMapDomainHasDim<3>]>>)]>; //===----------------------------------------------------------------------===// // Linalg subview operands promotion. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (LinalgOpPromoteSubviews<"MatmulOp"> $op), - [(Constraint> $op), - (Constraint> $op)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSubviewsLinalgOp<"MatmulOp">), + [(Constraint>), + (Constraint>)]>; #endif // TEST_LINALG_TRANSFORMS_PATTERNS