diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -9,9 +9,16 @@
 #ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
 #define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
 
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
 
+namespace mlir {
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Linalg Transform Operations
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -16,6 +16,81 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Interchanges the iterators of the operations pointed to by the target
+    handle using the `iterator_interchange` attribute.
+  }];
+
+  let arguments =
+    (ins PDL_Operation:$target,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
+def PadOp : Op<Transform_Dialect, "structured.pad",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Pads the operations pointed to by the target handle using the options
+    provided as operation attributes.
+  }];
+
+  let arguments =
+    (ins PDL_Operation:$target,
+         DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$hoist_paddings,
+         DefaultValuedAttr<
+             TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
+             "{}">:$transpose_paddings);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
+def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Indicates that ops of a specific kind in the given function should be
+    scalarized (i.e. their dynamic dimensions tiled by 1).
+
+    This operation returns the tiled op but not the loops.
+
+    We make this design choice because it is hard to know ahead of time the
+    number of loops that will be produced (it depends on the number of
+    dynamic dimensions after multiple transformations have been applied).
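+
+    For instance, assuming `%0` is a handle to the payload ops to scalarize
+    (e.g., as produced by `pdl_match`), an illustrative use is:
+
+    ```mlir
+    %1 = transform.structured.scalarize %0
+    ```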
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -397,6 +397,31 @@
   }
 };
 
+/// Trait implementing the TransformOpInterface for operations applying a
+/// transformation to a single operation handle and producing a single
+/// operation handle. The op must implement a method with one of the following
+/// signatures:
+///   - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
+///   - LogicalResult applyToOne(OpTy)
+/// to perform a transformation that is applied in turn to all payload IR
+/// operations that correspond to the handle of the transform IR operation.
+/// In the functions above, OpTy is either Operation * or a concrete payload
+/// IR Op class that the transformation is applied to (NOT the class of the
+/// transform IR op). The op is expected to have one operand and zero or one
+/// results.
+template <typename OpTy>
+class TransformEachOpTrait
+    : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
+public:
+  /// Calls `applyToOne` for every payload operation associated with the
+  /// operand of this transform IR op. If `applyToOne` returns ops, associates
+  /// them with the result of this transform op.
+  LogicalResult apply(TransformResults &transformResults,
+                      TransformState &state);
+
+  /// Checks that the op matches the expectations of this trait.
+  static LogicalResult verifyTrait(Operation *op);
+};
+
 /// Side effect resource corresponding to the mapping between Transform IR
 /// values and Payload IR operations. An Allocate effect from this resource
 /// means creating a new mapping entry, it is always accompanied by a Write
@@ -426,9 +451,150 @@
   StringRef getName() override { return "transform.payload_ir"; }
 };
 
+/// Trait implementing the MemoryEffectOpInterface for single-operand
+/// single-result operations that "consume" their operand and produce a new
+/// result.
+template <typename OpTy>
+class FunctionalStyleTransformOpTrait
+    : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
+public:
+  /// This op "consumes" the operand by reading and freeing it, "produces" the
+  /// result by allocating and writing it and reads/writes the payload IR in
+  /// the process.
+  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+    effects.emplace_back(MemoryEffects::Read::get(),
+                         this->getOperation()->getOperand(0),
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Free::get(),
+                         this->getOperation()->getOperand(0),
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Allocate::get(),
+                         this->getOperation()->getResult(0),
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(),
+                         this->getOperation()->getResult(0),
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+  }
+
+  /// Checks that the op matches the expectations of this trait.
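+  /// In particular, the op must have exactly one operand and exactly one
+  /// result, and must implement MemoryEffectOpInterface so that the effects
+  /// declared above can actually be queried.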
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
+                  "expected single-operand op");
+    static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
+                  "expected single-result op");
+    if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
+      return op->emitError()
+             << "FunctionalStyleTransformOpTrait should only be attached to "
+                "ops that implement MemoryEffectOpInterface";
+    }
+    return success();
+  }
+};
+
 } // namespace transform
 } // namespace mlir
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
 
+namespace mlir {
+namespace transform {
+namespace detail {
+/// Appends `result` to the vector assuming it corresponds to the success
+/// state in `FailureOr<convertible-to-Operation*>`. If `result` is just a
+/// `LogicalResult`, does nothing.
+template <typename Ty>
+std::enable_if_t<std::is_same<Ty, LogicalResult>::value, LogicalResult>
+appendTransformResultToVector(Ty result,
+                              SmallVectorImpl<Operation *> &results) {
+  return result;
+}
+template <typename Ty>
+std::enable_if_t<!std::is_same<Ty, LogicalResult>::value, LogicalResult>
+appendTransformResultToVector(Ty result,
+                              SmallVectorImpl<Operation *> &results) {
+  static_assert(
+      std::is_convertible<typename Ty::value_type, Operation *>::value,
+      "expected transform function to return operations");
+  if (failed(result))
+    return failure();
+
+  results.push_back(*result);
+  return success();
+}
+
+/// Applies a one-to-one transform to each of the given targets. Puts the
+/// results of transforms, if any, in `results` in the same order. Fails if
+/// any of the applications fails. Individual transforms must be callable with
+/// one of the following signatures:
+///   - FailureOr<convertible-to-Operation*>(OpTy)
+///   - LogicalResult(OpTy)
+/// where OpTy is either
+///   - Operation *, in which case the transform is always applied;
+///   - a concrete Op class, in which case a check is performed whether
+///   `targets` contains operations of the same class and a failure is
+///   reported if it does not.
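+/// For example, a transform callable of the form
+///   [](linalg::LinalgOp op) -> FailureOr<linalg::LinalgOp>
+/// is only invoked when every target is a linalg::LinalgOp; the whole
+/// application fails otherwise.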
+template <typename FnTy>
+LogicalResult applyTransformToEach(ArrayRef<Operation *> targets,
+                                   SmallVectorImpl<Operation *> &results,
+                                   FnTy transform) {
+  using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
+  static_assert(std::is_convertible<OpTy, Operation *>::value,
+                "expected transform function to take an operation");
+  using RetTy = typename llvm::function_traits<FnTy>::result_t;
+  static_assert(std::is_convertible<RetTy, LogicalResult>::value,
+                "expected transform function to return LogicalResult or "
+                "FailureOr<convertible-to-Operation*>");
+  for (Operation *target : targets) {
+    auto specificOp = dyn_cast<OpTy>(target);
+    if (!specificOp)
+      return failure();
+
+    auto result = transform(specificOp);
+    if (failed(appendTransformResultToVector(result, results)))
+      return failure();
+  }
+  return success();
+}
+} // namespace detail
+} // namespace transform
+} // namespace mlir
+
+template <typename OpTy>
+mlir::LogicalResult mlir::transform::TransformEachOpTrait<OpTy>::apply(
+    TransformResults &transformResults, TransformState &state) {
+  using TransformOpType = typename llvm::function_traits<
+      decltype(&OpTy::applyToOne)>::template arg_t<0>;
+  ArrayRef<Operation *> targets =
+      state.getPayloadOps(this->getOperation()->getOperand(0));
+  SmallVector<Operation *> results;
+  if (failed(detail::applyTransformToEach(
+          targets, results, [&](TransformOpType specificOp) {
+            return static_cast<OpTy *>(this)->applyToOne(specificOp);
+          })))
+    return failure();
+  if (OpTy::template hasTrait<OpTrait::OneResult>()) {
+    transformResults.set(
+        this->getOperation()->getResult(0).template cast<OpResult>(), results);
+  }
+  return success();
+}
+
+template <typename OpTy>
+mlir::LogicalResult
+mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
+  static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
+                "expected single-operand op");
+  static_assert(OpTy::template hasTrait<OpTrait::ZeroResults>() ||
+                    OpTy::template hasTrait<OpTrait::OneResult>(),
+                "expected zero- or single-result op");
+  if (!op->getName().getInterface<TransformOpInterface>()) {
+    return op->emitError() << "TransformEachOpTrait should only be attached "
+                              "to ops that implement TransformOpInterface";
+  }
+
+  return success();
+}
+
 #endif // DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -49,4 +49,13 @@
   ];
 }
 
+def FunctionalStyleTransformOpTrait
+    : NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+}
+
+def TransformEachOpTrait : NativeOpTrait<"TransformEachOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -49,6 +49,179 @@
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// InterchangeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
+  SmallVector<unsigned> interchangeVector =
+      extractUIntArray(getIteratorInterchange());
+  // Exit early if no transformation is needed.
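+  // An empty interchange vector amounts to the identity permutation, in which
+  // case the target op is returned unchanged.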
+  if (interchangeVector.empty())
+    return target;
+
+  auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
+  if (!genericTarget) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "applies to " << GenericOp::getOperationName()
+                              << " ops";
+    diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+    return diag;
+  }
+
+  GenericOpInterchangePattern pattern(getContext(), interchangeVector);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<GenericOp> result =
+      pattern.returningMatchAndRewrite(genericTarget, rewriter);
+  if (failed(result))
+    return failure();
+
+  return cast<LinalgOp>(result->getOperation());
+}
+
+LogicalResult transform::InterchangeOp::verify() {
+  SmallVector<unsigned> permutation =
+      extractUIntArray(getIteratorInterchange());
+  auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
+  if (!std::is_permutation(sequence.begin(), sequence.end(),
+                           permutation.begin(), permutation.end())) {
+    return emitOpError()
+           << "expects iterator_interchange to be a permutation, found "
+           << getIteratorInterchange();
+  }
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// PadOp
+//===---------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
+  // Convert the integer packing flags to booleans.
+  SmallVector<bool> packPaddings;
+  for (int64_t packPadding : extractI64Array(getPackPaddings()))
+    packPaddings.push_back(static_cast<bool>(packPadding));
+
+  // Convert the padding values to attributes.
+  SmallVector<Attribute> paddingValues;
+  for (auto const &it :
+       llvm::zip(getPaddingValues(), target->getOperandTypes())) {
+    Attribute attr = std::get<0>(it);
+    Type elementType = getElementTypeOrSelf(std::get<1>(it));
+    // Try to parse string attributes to obtain an attribute of element type.
+    if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+      paddingValues.push_back(
+          parseAttribute(attr.cast<StringAttr>(), elementType));
+      if (!paddingValues.back()) {
+        InFlightDiagnostic diag = emitOpError()
+                                  << "expects a padding value that parses to "
+                                  << elementType << ", got " << std::get<0>(it);
+        diag.attachNote(target.getLoc()) << "when applied to this op";
+        return diag;
+      }
+      continue;
+    }
+    // Otherwise, add the attribute directly.
+    if (attr.getType() != elementType) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "expects a padding value of type "
+                                << elementType << ", got " << attr;
+      diag.attachNote(target.getLoc()) << "when applied to this op";
+      return diag;
+    }
+    paddingValues.push_back(attr);
+  }
+
+  // Extract the transpose vectors.
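+  // The attribute is an array of i64 arrays, e.g. transpose_paddings =
+  // [[1, 0], [0, 1]]; the verifier below checks that each inner array is a
+  // permutation.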
+  SmallVector<SmallVector<int64_t>> transposePaddings;
+  for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
+    transposePaddings.push_back(
+        extractI64Array(transposeVector.cast<ArrayAttr>()));
+
+  LinalgPaddingOptions paddingOptions;
+  paddingOptions.setPaddingValues(paddingValues);
+  paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
+  paddingOptions.setPackPaddings(packPaddings);
+  paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
+  paddingOptions.setTransposePaddings(transposePaddings);
+
+  LinalgPaddingPattern pattern(getContext(), paddingOptions);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<LinalgOp> patternResult =
+      pattern.returningMatchAndRewrite(target, rewriter);
+  if (failed(patternResult)) {
+    InFlightDiagnostic diag = emitError()
+                              << "failed to apply pattern to target op";
+    diag.attachNote(target.getLoc()) << "target op";
+    return diag;
+  }
+  return patternResult;
+}
+
+LogicalResult transform::PadOp::verify() {
+  SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
+  if (any_of(packPaddings, [](int64_t packPadding) {
+        return packPadding != 0 && packPadding != 1;
+      })) {
+    return emitOpError()
+           << "expects pack_paddings to contain booleans (0/1), found "
+           << getPackPaddings();
+  }
+
+  SmallVector<int64_t> paddingDimensions =
+      extractI64Array(getPaddingDimensions());
+  if (any_of(paddingDimensions,
+             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
+    return emitOpError()
+           << "expects padding_dimensions to contain positive integers, found "
+           << getPaddingDimensions();
+  }
+
+  SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
+  if (any_of(hoistPaddings,
+             [](int64_t hoistPadding) { return hoistPadding < 0; })) {
+    return emitOpError()
+           << "expects hoist_paddings to contain positive integers, found "
+           << getHoistPaddings();
+  }
+
+  ArrayAttr transposes = getTransposePaddings();
+  for (Attribute attr : transposes) {
+    SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
+    auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
+    if (!std::is_permutation(sequence.begin(), sequence.end(),
+                             transpose.begin(), transpose.end())) {
+      return emitOpError()
+             << "expects transpose_paddings to be a permutation, found "
+             << attr;
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ScalarizeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
+  LinalgTilingOptions tilingOptions;
+  tilingOptions.scalarizeDynamicDims();
+  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
+  // tile sizes and asserts that it is not already set.
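+  // Consequently, no tile sizes are passed to the pattern here;
+  // scalarizeDynamicDims has already installed the tile-size computation
+  // function.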
+  SmallVector<int64_t> emptyTileSizes;
+  LinalgTilingPattern pattern(getContext(), tilingOptions);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<TiledLinalgOp> result =
+      pattern.returningMatchAndRewrite(target, rewriter);
+  if (failed(result))
+    return failure();
+
+  return result->op;
+}
+
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @interchange_generic
+func.func @interchange_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  // CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    %1 = math.exp %arg2 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_generic : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_generic in %arg1
+    transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+  }
+}
+
+// -----
+
+func.func @interchange_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-note @below {{attempted to apply to this op}}
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_generic : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_generic in %arg1
+    // expected-error @below {{applies to linalg.generic ops}}
+    transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter -split-input-file -verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
+// CHECK-LABEL: @static_sizes_output_divisible
+func.func @static_sizes_output_divisible(%arg0: tensor<24x12xf32>,
+                                         %arg1: tensor<12x25xf32>,
+                                         %arg2: tensor<24x25xf32>,
+                                         %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %
+  // CHECK: %[[T1:.*]] = tensor.extract_slice %
+  // CHECK: %[[T2:.*]] = tensor.extract_slice %
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  // CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  // CHECK: %[[T3:.*]] = tensor.pad %[[T0]] nofold
+  // CHECK: tensor.yield %[[CST]]
+  // CHECK: %[[T4:.*]] = tensor.pad %[[T1]] nofold
+
+  // CHECK: %[[T5:.*]] = linalg.matmul
+  // CHECK-SAME: ins(%[[T3]], %[[T4]] : tensor<4x7xf32>, tensor<7x5xf32>)
+  // CHECK-SAME: outs(%[[T2]] : tensor<4x5xf32>)
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+  }
+}
+
+// -----
+
+func.func @pad(%arg0: tensor<24x12xf32>,
+               %arg1: tensor<12x25xf32>,
+               %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // expected-note @below {{when applied to this op}}
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    // expected-error @below {{op expects a padding value of type 'f32', got 0 : i32}}
+    %1 = transform.structured.pad %0 {padding_values=[0: i32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+  }
+}
+
+// -----
+
+func.func @pad(%arg0: tensor<24x12xf32>,
+               %arg1: tensor<12x25xf32>,
+               %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // expected-note @below {{when applied to this op}}
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    // expected-error @below {{expects a padding value that parses to 'f32', got "foo"}}
+    %1 = transform.structured.pad %0 {padding_values=["foo", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+  }
+}
+
+// -----
+
+func.func @pad(%arg0: tensor<24x12xf32>,
+               %arg1: tensor<12x25xf32>,
+               %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    // expected-error @below {{failed to apply pattern to target op}}
+    %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter %s | FileCheck %s
+
+func.func @scalarize(%arg0: tensor<24x12xf32>,
+                     %arg1: tensor<12x25xf32>,
+                     %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // The op is first tiled by 10 in the first dimension, which creates a
+  // dynamic size, and then scalarized, which brings the dimension to static 1.
+  // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1, %loops = transform.structured.tile %0 {sizes = [10, 0, 0]}
+    %2 = transform.structured.scalarize %1
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
copy from mlir/test/Dialect/Linalg/transform-ops.mlir
copy to mlir/test/Dialect/Linalg/transform-op-tile.mlir
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}}
+  transform.structured.interchange %arg0 {iterator_interchange = [1, 1]}
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects padding_dimensions to contain positive integers, found [1, -7]}}
+  transform.structured.pad %arg0 {padding_dimensions=[1, -7]}
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects pack_paddings to contain booleans (0/1), found [1, 7]}}
+  transform.structured.pad %arg0 {pack_paddings=[1, 7]}
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects hoist_paddings to contain positive integers, found [1, -7]}}
+  transform.structured.pad %arg0 {hoist_paddings=[1, -7]}
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}}
+  transform.structured.pad %arg0 {transpose_paddings=[[1, 1]]}
+}
diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -1,46 +1,31 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
-  sequence %arg0 {
-  ^bb0(%arg1: !pdl.operation):
-    %0 = pdl_match @pdl_target in %arg1
-    %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
-  }
-
-  pdl.pattern @pdl_target : benefit(1) {
-    %args = operands
-    %results = types
-    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
-    rewrite %0 with "transform.dialect"
-  }
+transform.sequence {
+^bb1(%arg0: !pdl.operation):
+  // CHECK: %{{.*}}, %{{.*}}:2 = transform.structured.tile
+  %0, %1:2 = transform.structured.tile %arg0 { sizes = [2, 0, 3] }
 }
 
-// CHECK-LABEL: func @tile_linalg_matmul(
-// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME:  -> tensor<128x128xf32> {
-func.func @tile_linalg_matmul(
-  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
-    -> tensor<128x128xf32> {
-// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
-// CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
-// CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
-// CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-// CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-// CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-// CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
-// CHECK-SAME:                             outs(%[[sTC]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<4x4xf32> into tensor<128x128xf32>
-// CHECK:     scf.yield %[[TD]] : tensor<128x128xf32>
-// CHECK:   scf.yield %[[TD2]] : tensor<128x128xf32>
-// CHECK: scf.yield %[[TD1]] : tensor<128x128xf32>
-  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
-                    outs(%arg2: tensor<128x128xf32>)
-    -> tensor<128x128xf32>
+//===----------------------------------------------------------------------===//
+// Check that operations are registered correctly through the extension
+// mechanism. Their syntax is generated and requires no additional testing
+// since we test the generator.
+//===----------------------------------------------------------------------===// + +transform.sequence { +^bb1(%arg0: !pdl.operation): + // CHECK: transform.structured.pad + %0 = transform.structured.pad %arg0 +} -// CHECK: return %[[TD0]] : tensor<128x128xf32> - return %0 : tensor<128x128xf32> +transform.sequence { +^bb1(%arg0: !pdl.operation): + // CHECK: transform.structured.interchange + %0 = transform.structured.interchange %arg0 } +transform.sequence { +^bb1(%arg0: !pdl.operation): + // CHECK: transform.structured.scalarize + %0 = transform.structured.scalarize %arg0 +}