diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -19,6 +19,10 @@ include "mlir/IR/OpBase.td" include "mlir/IR/RegionKindInterface.td" +//===----------------------------------------------------------------------===// +// DecomposeOp +//===----------------------------------------------------------------------===// + def DecomposeOp : Op { @@ -48,6 +52,10 @@ }]; } +//===----------------------------------------------------------------------===// +// FuseOp +//===----------------------------------------------------------------------===// + def FuseOp : Op]> { @@ -67,6 +75,10 @@ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// FuseIntoContainingOp +//===----------------------------------------------------------------------===// + def FuseIntoContainingOp : Op]> { @@ -120,6 +132,10 @@ ]; } +//===----------------------------------------------------------------------===// +// GeneralizeOp +//===----------------------------------------------------------------------===// + def GeneralizeOp : Op { @@ -149,6 +165,10 @@ }]; } +//===----------------------------------------------------------------------===// +// InterchangeOp +//===----------------------------------------------------------------------===// + def InterchangeOp : Op { @@ -169,10 +189,14 @@ let arguments = (ins PDL_Operation:$target, - DefaultValuedAttr:$iterator_interchange); + ConfinedAttr, + [DenseArrayNonNegative]>:$iterator_interchange); let results = (outs PDL_Operation:$transformed); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = [{ + $target + (`iterator_interchange` `=` $iterator_interchange^)? attr-dict + }]; let hasVerifier = 1; let extraClassDeclaration = [{ @@ -183,6 +207,10 @@ }]; } +//===----------------------------------------------------------------------===// +// MatchOp +//===----------------------------------------------------------------------===// + def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match", [ I32EnumAttrCase<"LinalgOp", 0>, @@ -245,6 +273,10 @@ }]; } +//===----------------------------------------------------------------------===// +// MultiTileSizesOp +//===----------------------------------------------------------------------===// + def MultiTileSizesOp : Op, TransformOpInterface, TransformEachOpTrait]> { @@ -309,6 +341,10 @@ }]; } +//===----------------------------------------------------------------------===// +// PadOp +//===----------------------------------------------------------------------===// + def PadOp : Op { @@ -349,6 +385,10 @@ }]; } +//===----------------------------------------------------------------------===// +// PromoteOp +//===----------------------------------------------------------------------===// + def PromoteOp : Op { @@ -388,6 +428,10 @@ }]; } +//===----------------------------------------------------------------------===// +// ReplaceOp +//===----------------------------------------------------------------------===// + def ReplaceOp : Op, DeclareOpInterfaceMethods] # GraphRegionNoTerminator.traits> { @@ -410,6 +454,10 @@ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// ScalarizeOp +//===----------------------------------------------------------------------===// + def ScalarizeOp : Op { @@ -449,6 +497,10 @@ }]; } +//===----------------------------------------------------------------------===// +// SplitOp +//===----------------------------------------------------------------------===// + def SplitOp : Op, DeclareOpInterfaceMethods]> { @@ -481,6 +533,10 @@ let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// SplitReductionOp +//===----------------------------------------------------------------------===// + def SplitReductionOp : Op { @@ -649,6 +705,10 @@ }]; } +//===----------------------------------------------------------------------===// +// TileReductionUsingScfOp +//===----------------------------------------------------------------------===// + def TileReductionUsingScfOp : Op { @@ -738,6 +798,10 @@ }]; } +//===----------------------------------------------------------------------===// +// TileReductionUsingForeachThreadOp +//===----------------------------------------------------------------------===// + def TileReductionUsingForeachThreadOp : Op, DeclareOpInterfaceMethods]> { @@ -884,6 +952,10 @@ }]; } +//===----------------------------------------------------------------------===// +// TileToForeachThreadOp +//===----------------------------------------------------------------------===// + def TileToForeachThreadOp : Op, DeclareOpInterfaceMethods]> { @@ -1054,6 +1130,10 @@ }]; } +//===----------------------------------------------------------------------===// +// VectorizeOp +//===----------------------------------------------------------------------===// + def VectorizeOp : Op { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -34,16 +34,6 @@ #define DEBUG_TYPE "linalg-transforms" -/// Extracts a vector of unsigned from an array attribute. Asserts if the -/// attribute contains values other than intergers. May truncate. -static SmallVector extractUIntArray(ArrayAttr attr) { - SmallVector result; - result.reserve(attr.size()); - for (APInt value : attr.getAsValueRange()) - result.push_back(value.getZExtValue()); - return result; -} - /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` /// function that returns the "main" result or failure. Returns failure if the @@ -604,8 +594,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target, SmallVectorImpl &results, transform::TransformState &state) { - SmallVector interchangeVector = - extractUIntArray(getIteratorInterchange()); + ArrayRef interchangeVector = getIteratorInterchange(); // Exit early if no transformation is needed. if (interchangeVector.empty()) { results.push_back(target); @@ -613,7 +602,9 @@ } TrivialPatternRewriter rewriter(target->getContext()); FailureOr res = - interchangeGenericOp(rewriter, target, interchangeVector); + interchangeGenericOp(rewriter, target, + SmallVector(interchangeVector.begin(), + interchangeVector.end())); if (failed(res)) return DiagnosedSilenceableFailure::definiteFailure(); results.push_back(res->getOperation()); @@ -621,9 +612,8 @@ } LogicalResult transform::InterchangeOp::verify() { - SmallVector permutation = - extractUIntArray(getIteratorInterchange()); - auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); + ArrayRef permutation = getIteratorInterchange(); + auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -257,7 +257,7 @@ } //===----------------------------------------------------------------------===// -// ForeachOp +// CastOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure diff --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir --- a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir @@ -21,7 +21,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - transform.structured.interchange %0 { iterator_interchange = [1, 0]} + transform.structured.interchange %0 iterator_interchange = [1, 0] } // ----- @@ -36,5 +36,5 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 // expected-error @below {{transform applied to the wrong op kind}} - transform.structured.interchange %0 { iterator_interchange = [1, 0]} + transform.structured.interchange %0 iterator_interchange = [1, 0] } diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir --- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir @@ -2,8 +2,8 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}} - transform.structured.interchange %arg0 {iterator_interchange = [1, 1]} + // expected-error@below {{'transform.structured.interchange' op expects iterator_interchange to be a permutation, found 1, 1}} + transform.structured.interchange %arg0 iterator_interchange = [1, 1] } // ----- @@ -37,3 +37,11 @@ // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}} transform.structured.pad %arg0 {transpose_paddings=[[1, 1]]} } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + // expected-error@below {{'transform.structured.interchange' op attribute 'iterator_interchange' failed to satisfy constraint: i64 dense array attribute whose value is non-negative}} + transform.structured.interchange %arg0 iterator_interchange = [-3, 1] +} diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -138,7 +138,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]} + transform.structured.interchange %0 iterator_interchange = [1, 2, 0] } // CHECK-LABEL: func @permute_generic