diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md --- a/mlir/docs/Dialects/Transform.md +++ b/mlir/docs/Dialects/Transform.md @@ -1,23 +1,9 @@ # Transform Dialect -Fine-grain transformation control dialect. +Fine-grain transformation control dialect. See the [tutorial](../Tutorials/transform) for more introductory information. [TOC] -## Disclaimer - -**This dialect is actively developed and may change frequently.** - -To decrease the maintenance burden and churn, please post a description of -the intended use case on the MLIR forum. A few in-tree use cases are -currently supported: - - - high-level transformations on "structured ops" (i.e. ops that operate on - chunks of data in a way that can be decomposed into operations on - smaller chunks of data and control flow) in Linalg, Tensor and Vector - dialects; - - loop transformations in the SCF dialect. - ## Overview This dialect provides operations that can be used to control transformation diff --git a/mlir/docs/Tutorials/_index.md b/mlir/docs/Tutorials/_index.md --- a/mlir/docs/Tutorials/_index.md +++ b/mlir/docs/Tutorials/_index.md @@ -2,3 +2,4 @@ This section contains multiple MLIR tutorials. See [Toy tutorial](toy) for an introduction to using MLIR infrastructure. +See [Transform dialect tutorial](transform) for an introduction to using and extending MLIR's Transform dialect. diff --git a/mlir/docs/Tutorials/transform/Ch0.md b/mlir/docs/Tutorials/transform/Ch0.md new file mode 100644 --- /dev/null +++ b/mlir/docs/Tutorials/transform/Ch0.md @@ -0,0 +1,314 @@ +# Chapter 0: A Primer on “Structured” Linalg Operations + +Before starting the tutorial on the Transform dialect, let us take a brief look at the concept of Structured operations and its implementation in the Linalg dialect. Note that the Transform dialect does not require Structured operations and vice versa.
The two co-evolved in the early days of the Transform dialect, which makes the subset of transformations for Structured operations the most mature and the most suitable for the tutorial. If you are already familiar with this concept, skip to Chapter 1. + +Structured code generation intends to preserve the structure of the computation for as long as necessary to enable transformations, up to and including the design of IR abstractions that support specific transformations. + +## Uniform Elementwise Extension + +Consider a simple scalar arithmetic addition operation in MLIR, which maps directly to a machine instruction on most architectures that support floating-point operations: + + +```mlir +%2 = arith.addf %0, %1 : f32 +``` + +This operation can be easily extended to uniformly apply to elements of a 1D vector, which is also often available as an instruction of vector machines: + +```mlir +%2 = arith.addf %0, %1 : vector<8xf32> +``` + +Only a few modern instruction sets offer instructions for two- or more-dimensional vectors. In MLIR, however, it is possible to transparently extend the uniform elementwise application to vectors of arbitrary rank. + +```mlir +%2 = arith.addf %0, %1 : vector<8x4xf32> +%5 = arith.addf %3, %4 : vector<2x2x2x2x2x2x2xf32> +``` + +As you may notice, MLIR’s arithmetic operations on vectors preserve the structure of uniform elementwise application. This structure can be leveraged by the compiler, for example, to produce smaller-rank operations available on the target or to fuse multiplication and addition when such a fused instruction is available (which becomes complicated when there are a hundred multiplications followed by a hundred additions). + +## Reduction + +Sometimes it is necessary to add elements of a vector to obtain a scalar. Some platforms provide specific instructions for this operation; others provide instructions that can be combined to achieve the desired effect, such as addition of adjacent elements and element shuffle.
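To make the two implementation strategies concrete, here is a minimal plain-Python sketch (not MLIR, and not tied to any real instruction set) that contrasts the sequential loop form of an "add" reduction with a log-depth strategy built from a vector add and an element shuffle. The helper `shuffle_down` is a hypothetical stand-in for a target shuffle instruction, and the sketch assumes the vector length is a power of two.

```python
def reduce_sequential(vec):
    """Loop form: accumulate the elements one at a time, like an scf.for loop."""
    acc = 0.0
    for element in vec:
        acc += element
    return acc


def shuffle_down(vec, offset):
    """Hypothetical stand-in for an element-shuffle instruction: lane i receives
    lane i + offset; lanes shifted in from past the end are filled with 0.0."""
    n = len(vec)
    return [vec[i + offset] if i + offset < n else 0.0 for i in range(n)]


def reduce_shuffle_add(vec):
    """Tree form: add the vector to a shuffled copy of itself log2(n) times,
    halving the shuffle distance each time; lane 0 ends up holding the sum.
    Assumes the vector length is a power of two."""
    n = len(vec)
    assert n > 0 and n & (n - 1) == 0, "vector length must be a power of two"
    offset = n // 2
    while offset >= 1:
        # Lane-wise addition of the vector and its shuffled copy (a vector add).
        vec = [a + b for a, b in zip(vec, shuffle_down(vec, offset))]
        offset //= 2
    return vec[0]


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(reduce_sequential(values), reduce_shuffle_add(values))  # prints 36.0 36.0
```

Note that the two versions add the elements in different orders. With floating-point values this can round differently for some inputs, so the choice between them is not always result-neutral.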
+ +The Vector dialect in MLIR defines an operation to explicitly denote a within-vector reduction: + +```mlir +%1 = vector.reduction <add>, %0 : vector<8xf32> into f32 +``` + +When no support is available, such an operation can be transformed into a loop: + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%c8 = arith.constant 8 : index +%init = arith.constant 0.0 : f32 +%result = scf.for %i = %c0 to %c8 step %c1 iter_args(%partial = %init) -> (f32) { + %element = vector.extractelement %0[%i : index] : vector<8xf32> + %updated = arith.addf %partial, %element : f32 + scf.yield %updated : f32 +} +``` + +Even when special instructions are available, it may still be desirable to use the loop form (with unrolling), depending on instruction latency and register pressure. Preserving the structure of the operation as a single reduction gives the compiler an understanding that a within-vector reduction is performed and, therefore, a choice in implementation. + +## Contraction + +Contraction is a generalization of reduction that multiplies elements from two vectors before adding them up. A simple “add” reduction can be thought of as a contraction where one of the vectors contains `1.0`, the neutral element of multiplication. Contractions offer even more flexibility to the compiler and are represented by a dedicated operation in MLIR: + +```mlir +// Neutral initializer for the addition. +%init = arith.constant 0.0 : f32 +// Neutral element of multiplication. +%ones = arith.constant dense<1.0> : vector<8xf32> +// Actual contraction. +%result = vector.contract { + indexing_maps = [affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()>], + iterator_types = ["reduction"] +} %0, %ones, %init : vector<8xf32>, vector<8xf32> into f32 +``` + +Note the `affine_map` expressions indicating how vector elements are indexed.
Their meaning is perhaps most evident when writing the loop form pseudo-code equivalent to this contraction: + +```mlir +for i in 0 to 8: + init += p0[i] * ones[i] +``` + +where both `%0` and `%ones` use the loop induction variable `i`, as noted on the right-hand side of the corresponding affine map, `(i) -> (i)`, and `%init` does not, as reflected on the right-hand side of its affine map, `(i) -> ()`. + +Similarly to uniform elementwise extension, MLIR vector contractions are not limited to 1D cases. In the 2D+ case, one can additionally specify which of the vector dimensions are being reduced and which ones are being preserved. This can be achieved by using the `iterator_types` attribute that specifies, for each dimension, whether it is being reduced (`"reduction"`) or preserved (`"parallel"`). Consider the following 3D contraction that encodes a matrix-matrix multiplication: + +```mlir +%result = vector.contract { + indexing_maps = [affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)>], + iterator_types = ["parallel", "parallel", "reduction"] +} %lhs, %rhs, %init: vector<8x10xf32>, vector<10x16xf32> into vector<8x16xf32> +``` + +Looking at the indexing maps, it is easy to recognize the loop form: + +```mlir +for i in 0 to 8: + for j in 0 to 16: + for k in 0 to 10: + init[i, j] += lhs[i, k] * rhs[k, j] +``` + +Preserving this higher-level structure of a contraction makes it significantly easier for the compiler to recognize operations such as matrix multiplications and dot products and gives it the freedom to produce lower-level operations that leverage the most advanced instructions or even pre-generated microkernels. + +## Generic Operation on Memory + +Until now, we have been considering operations on vectors stored in virtual registers.
A similar contraction abstraction can be defined in memory: + +```mlir +linalg.generic { + indexing_maps = [affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)>], + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%lhs, %rhs : memref<8x10xf32>, memref<10x16xf32>) + outs(%init : memref<8x16xf32>) { +^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32): + %0 = arith.mulf %lhs_one, %rhs_one : f32 + %1 = arith.addf %init_one, %0 : f32 + linalg.yield %1 : f32 +} +``` + +This looks more complicated, so let us unpack it. The `indexing_maps` and `iterator_types` are _exactly_ the same as we have seen above for vector contractions. The operands are now split into two lists: + + +* `in` operands (listed in `ins`) containing the buffers that are only read by the operation; +* `out` operands (listed in `outs`) that are both read and updated by the operation. + +This separation wasn’t necessary on vectors because, in MLIR, vectors are read-only (SSA or functional form) and operations mutating a vector are in fact producing a new one instead. + +Furthermore, the operation now contains a region that explicitly specifies the multiplication and the addition operations that were implicit in the contraction. Block arguments in the region correspond to individual elements read from the buffers: the first two correspond to the `in` operands and the last one corresponds to the `out` operand. The value yielded from the region is “written” to the `out` operand and is available as the last block argument for future executions of the region. Note that the order in which the region is executed for various tuples of elements read from the buffers is not specified, and the `out` buffer is written as a whole at the end of the operation. + +## “Loop” Fusion + +Since the region of the `generic` operation can contain arbitrarily many operations, we can use it to express “fusion” of the implicit loops by simply having more operations chained in the region.
For example, the common machine learning rectified linear unit layer (ReLU), which can be defined as `relu(x) = max(0, x)`, can be expressed using the “compare-and-select” idiom in one `generic` operation, without the temporary buffer for the comparison result and without repeating the outer operation: + +```mlir +linalg.generic { + indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>], + iterator_types = ["parallel"] +} ins(%in : memref<?xf32>) outs(%out : memref<?xf32>) { +^bb0(%in_one : f32, %out_one : f32): + %c0 = arith.constant 0.0 : f32 + %0 = arith.cmpf ogt, %in_one, %c0 : f32 + %1 = arith.select %0, %in_one, %c0 : f32 + linalg.yield %1 : f32 +} +``` + +Such operations can be converted to loops or lowered into vector forms after splitting into multiple operations, each of which maps to a Vector dialect primitive. This modeling, again, gives the compiler more choice in selecting the code generation strategy. + +## Generic Operation on Tensors + +Let us take one last step up the abstraction ladder. MLIR provides a tensor abstraction that makes it easy for the compiler to reason about multidimensional yet regular data without having to solve complex problems such as alias analysis and dependency satisfaction, which would be necessary on multidimensional buffers. The tensor abstraction is very similar to the vector abstraction (major differences include the availability of unranked tensors, tensor layouts, and vectors being usable as element types of tensors but not of other vectors). Tensors are read-only, and operations updating a tensor produce a new tensor.
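The semantics of `generic` described so far can be summarized by a small reference model. The following plain-Python sketch is neither MLIR code nor Linalg's implementation: `run_generic`, the lambda-based indexing maps, and the nested-list "tensors" are illustrative assumptions. It enumerates the iteration space, reads one element per operand through that operand's indexing map, applies the body, and writes into a copy of the `out` operand, so the initial value is left untouched, mirroring the value semantics of tensors. For simplicity, the loop bounds are passed explicitly rather than inferred from the operand shapes.

```python
import copy
import itertools


def load(tensor, index):
    """Read one element of a nested-list "tensor" at a multi-dimensional index."""
    for i in index:
        tensor = tensor[i]
    return tensor


def store(tensor, index, value):
    """Write one element of a nested-list "tensor" (rank >= 1) in place."""
    for i in index[:-1]:
        tensor = tensor[i]
    tensor[index[-1]] = value


def run_generic(loop_bounds, indexing_maps, ins, init, body):
    """Hypothetical reference model of a single-output `generic` on tensors.

    `indexing_maps` holds one map per input followed by one for the output;
    each map turns the loop indices into an operand index. `body` receives one
    element per input plus the current output element and returns the new
    output element, like the region terminated by `linalg.yield`.
    """
    result = copy.deepcopy(init)  # value semantics: `init` is never mutated
    *in_maps, out_map = indexing_maps
    # Visit points in lexicographic order; this is just one of the orders
    # allowed, since the execution order of the body is unspecified.
    for point in itertools.product(*(range(b) for b in loop_bounds)):
        args = [load(t, m(*point)) for t, m in zip(ins, in_maps)]
        out_index = out_map(*point)
        store(result, out_index, body(*args, load(result, out_index)))
    return result


# Matrix multiplication with the same maps as the examples above:
# (i, j, k) -> (i, k), (i, j, k) -> (k, j), (i, j, k) -> (i, j).
lhs = [[1, 2, 3], [4, 5, 6]]       # 2x3
rhs = [[7, 8], [9, 10], [11, 12]]  # 3x2
init = [[0, 0], [0, 0]]            # 2x2
matmul_result = run_generic(
    loop_bounds=(2, 2, 3),
    indexing_maps=[lambda i, j, k: (i, k),
                   lambda i, j, k: (k, j),
                   lambda i, j, k: (i, j)],
    ins=[lhs, rhs],
    init=init,
    body=lambda l, r, acc: acc + l * r,
)
# matmul_result == [[58, 64], [139, 154]]; `init` is still all zeros.

# The fused ReLU: compare-and-select in a single body.
relu_result = run_generic(
    loop_bounds=(4,),
    indexing_maps=[lambda i: (i,), lambda i: (i,)],
    ins=[[-1.0, 2.0, -3.0, 4.0]],
    init=[0.0] * 4,
    body=lambda x, _out: x if x > 0.0 else 0.0,
)
```

The same function serves for the buffer and tensor forms of the matmul; only the copy of `init` models the difference that a tensor-level operation returns a new value instead of updating its `out` operand.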
+ +The `generic` operation from above can be lifted to operate on tensors instead of buffers: + +```mlir +%result = linalg.generic { + indexing_maps = [affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)>], + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%lhs, %rhs : tensor<8x10xf32>, tensor<10x16xf32>) + outs(%init : tensor<8x16xf32>) { +^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32): + %0 = arith.mulf %lhs_one, %rhs_one : f32 + %1 = arith.addf %init_one, %0 : f32 + linalg.yield %1 : f32 +} -> tensor<8x16xf32> +``` + +As you may notice, most components of this operation remain identical to its buffer version. It has been specifically designed this way. The main difference, besides the operand types, is that the operation now produces a new result instead of updating the `out` buffer. The `out` operand is used only as the initialization value. + +If the `generic` operation had existed on vectors, it would have had the exact same structure. + +## Tiling and Loop Materialization + +At this level of abstraction, it becomes easy for the compiler to perform more advanced transformations usually required for high-performance code generation, such as [tiling](https://en.wikipedia.org/wiki/Loop_nest_optimization). Tiling, in general, can be seen as partitioning the iteration space into smaller parts, or tiles, so that the data required by each part fits into a level of cache, for example. The order in which tiles are executed must preserve the original data dependencies. + +In the case of `generic` operations, the iteration space is implicit and is defined by the shape of the operands. Therefore, a tile can be expressed by performing the _same_ operation on a subset (slice) of the original data. Since the order in which the body of `generic` is applied to different tuples of the input elements is unspecified, tiles can be executed in any order, without the need for dependence analysis.
In order to control the execution of different tiles, the implementation of tiling produces loops. Thus, tiling `generic` operations can also be seen as materializing the loops that have been implicit until now. + +For example, by tiling the matrix multiplication presented above with tile sizes `(2, 8)`, we obtain a loop nest around a `generic` expressing the same operation on a `2x8` tensor. + +```mlir +// A special "multi-for" loop that supports tensor-insertion semantics +// as opposed to implicit updates. The resulting 8x16 tensor will be produced +// by this loop. +// The trip count of iterators is computed by dividing the original tensor size, +// 8x16, by the tile size, 2x8, to obtain 4x2. +// When tensor sizes are dynamic, the trip count computation is emitted as IR +// and is computed at runtime. +%0 = scf.forall (%i, %j) in (4, 2) + shared_outs(%shared = %init) -> (tensor<8x16xf32>) { + + // Scale the loop induction variables by the tile sizes. + %3 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i) + %4 = affine.apply affine_map<(d0) -> (d0 * 8)>(%j) + + // Take slices of inputs and outputs. Only the "i" and "j" dimensions are sliced. + %lhs_slice = tensor.extract_slice %lhs[%3, 0] [2, 10] [1, 1] + : tensor<8x10xf32> to tensor<2x10xf32> + %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1] + : tensor<10x16xf32> to tensor<10x8xf32> + %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1] + : tensor<8x16xf32> to tensor<2x8xf32> + + // This is exactly the same operation as before, but now operating on smaller + // slices of data.
+ %partial = linalg.generic { + indexing_maps = [affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>) + outs(%result_slice : tensor<2x8xf32>) { + ^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32): + %0 = arith.mulf %lhs_one, %rhs_one : f32 + %1 = arith.addf %init_one, %0 : f32 + linalg.yield %1 : f32 + } -> tensor<2x8xf32> + + // Terminator for the loop with tensor-insertion semantics. Inserts a slice + // into a larger tensor, potentially in parallel. + scf.forall.in_parallel { + tensor.parallel_insert_slice %partial into %shared[%3, %4] [2, 8] [1, 1] + : tensor<2x8xf32> into tensor<8x16xf32> + } +} +``` + +## Producer/Consumer Fusion and Rematerialization + +After materializing loops with tiling, another key code generation transformation becomes simple: fusion. Unlike loop fusion, the Structured operations approach allows for producer/consumer fusion even when the (implicit) iteration spaces of the operations do not match. Given a high-level structured operation on tensors, such as `linalg.generic`, one can follow use-def chains to identify: + +1. the subset (slice) of the operand that is used by the tile, and +2. the tensor-level structured operation producing the whole tensor that is being sliced. + +By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus, fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand. + +Let us assume that the matrix multiplication operation is followed by another operation that multiplies each element of the resulting matrix with itself.
This trailing elementwise operation has a 2D iteration space, unlike the 3D one in matrix multiplication. Nevertheless, it is possible to tile the trailing operation and then fuse the producer of its operand, the matmul, into the loop generated by tiling. The untiled reduction dimension of the matmul (`k`) is used in its entirety. + + +```mlir +// Same loop as before. +%0 = scf.forall (%i, %j) in (4, 2) + shared_outs(%shared = %init) + -> (tensor<8x16xf32>) { + // Scale the loop induction variables by the tile sizes. + %1 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i) + %2 = affine.apply affine_map<(d0) -> (d0 * 8)>(%j) + + // Take slices of inputs and outputs. Only the "i" and "j" dimensions are sliced. + %lhs_slice = tensor.extract_slice %lhs[%1, 0] [2, 10] [1, 1] + : tensor<8x10xf32> to tensor<2x10xf32> + %rhs_slice = tensor.extract_slice %rhs[0, %2] [10, 8] [1, 1] + : tensor<10x16xf32> to tensor<10x8xf32> + %result_slice = tensor.extract_slice %result[%1, %2] [2, 8] [1, 1] + : tensor<8x16xf32> to tensor<2x8xf32> + + // This is exactly the same matmul slice as before. It replaces the slice + // extraction for the generic operation below. + %partial = linalg.generic { + indexing_maps = [affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>) + outs(%result_slice : tensor<2x8xf32>) { + ^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32): + %5 = arith.mulf %lhs_one, %rhs_one : f32 + %6 = arith.addf %init_one, %5 : f32 + linalg.yield %6 : f32 + } -> tensor<2x8xf32> + + // Take the slice of the final result. Note that we don't need to take + // the slice of the operand because the matmul operation above computes + // it in-place. + %shared_slice = tensor.extract_slice %shared[%1, %2] [2, 8] [1, 1] + : tensor<8x16xf32> to tensor<2x8xf32> + + // The elementwise operation that we tiled.
+ %elemwise = linalg.generic { + indexing_maps = [affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>], + iterator_types = ["parallel", "parallel"] + } ins(%partial : tensor<2x8xf32>) + outs(%shared_slice : tensor<2x8xf32>) { + ^bb0(%in: f32, %out: f32): + %5 = arith.mulf %in, %in : f32 + linalg.yield %5 : f32 + } -> tensor<2x8xf32> + + // Terminator for the loop with tensor-insertion semantics. Inserts a slice + // into a larger tensor, potentially in parallel. + scf.forall.in_parallel { + tensor.parallel_insert_slice %elemwise into %shared[%1, %2] [2, 8] [1, 1] + : tensor<2x8xf32> into tensor<8x16xf32> + } +} +``` + +This process may result in some elements in the operand tensors being (re)computed on every iteration of the loop. This is also known as _rematerialization_ and expresses the tradeoff between performing redundant computations and storing their result in (slow) memory. + +## Shorthand “Named” Forms of Linalg Ops + +Linalg provides a set of predefined operations for common cases such as matrix multiplication, dot product, convolution, etc. These operations are equivalent to the `generic` ones but spare the user from spelling out the access patterns and the bodies. For example, matrix multiplication is simply: + +```mlir +%matmul = linalg.matmul ins(%lhs, %rhs: tensor<8x10xf32>, tensor<10x16xf32>) + outs(%init: tensor<8x16xf32>) -> tensor<8x16xf32> +``` diff --git a/mlir/docs/Tutorials/transform/Ch1.md b/mlir/docs/Tutorials/transform/Ch1.md new file mode 100644 --- /dev/null +++ b/mlir/docs/Tutorials/transform/Ch1.md @@ -0,0 +1,364 @@ +# Chapter 1: Combining Existing Transformations + +## Introduction + +The Transform dialect allows one to precisely target transformations at specific operations in the IR and to chain them, that is, to apply a transformation to operations produced by the previous transformation. To achieve this, transformations are expressed as other operations in the IR.
We call the IR containing these operations the transform IR, and the IR that is being transformed the payload IR. + +Transform IR operations operate on values that may be associated with payload IR operations, values or attributes. We call the first two kinds of values operation and value handles, respectively. We call the last kind of values parameters. + +The application of transform IR always starts from one top-level operation. In the C++ API, this operation is passed to the `applyTransforms` function. This top-level operation specifies if other transformations should be performed and how. The most common top-level operation merely applies other transform operations listed in its body one after the other. + +Let us illustrate this with a simple sequence of transformations on the common “fully connected + bias + ReLU” ML layer, which boils down to performing a matrix multiplication, followed by an (elementwise) matrix addition and taking an elementwise maximum with 0. This can be expressed using the following IR: + +```mlir +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU).
+ %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} +``` + +## Top-Level Sequence Operation + +For performance reasons, we would like to tile and fuse these operations to exploit cache locality. This is a sequence of transformations that need to be performed one after another, so we naturally start with the corresponding top-level transform operation. + +```mlir +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + transform.yield +} +``` + +There are several aspects worth noticing in this operation. + +The first entry block argument is mandatory for top-level transform operations and is associated with the top-level payload operation that the sequence is applied to, for example, a module or a function. This operation is specified when calling `applyTransforms`. + +The remaining entry block arguments are optional and can be associated with payload attributes, operations or values that are useful in the sequence. These are also specified when calling `applyTransforms`. In our case, we are interested in the matrix multiplication and elementwise operations that we are going to tile and fuse. + +All value handles have Transform dialect types. These types specify certain properties of the payload IR entities associated with them. In this example, `transform.any_op` indicates that the handle is associated with arbitrary payload operations. In contrast, `transform.op<"X">` indicates that the handle is associated _only_ with payload operations of kind `X`. These constraints are verified when the handle/payload association is created. For entry block arguments of top-level transform operations, this happens early in the `applyTransforms` function.
If the constraints are not satisfied, the transform application fails and produces diagnostics for the user. + +## Failure Propagation + +Speaking of diagnostics, the `sequence` operation itself has a mandatory attribute specifying the failure propagation mode. There are two options: + +* “propagate” makes the sequence transformation fail if any of the nested transformations fails; +* “suppress” makes the sequence succeed even if one of the nested transformations fails, but without attempting to perform the transformations following the failed one in the sequence. + +The latter allows the transformation to continue despite (recoverable) errors. As we are only building the transformation, it is preferable to propagate failures so we know when something did not apply. + +To check or debug a transform sequence, it is possible to print various entities associated with the transform IR values. For example, we can print the operations associated with the handles: + +```mlir +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + transform.test_print_remark_at_operand %arg1, "matmul" + : !transform.op<"linalg.matmul"> + transform.test_print_remark_at_operand %arg2, "elemwise_binaries" + : !transform.op<"linalg.elemwise_binary"> + transform.yield +} +``` + +## Transform Dialect Interpreter + +Since we don’t want to recompile the compiler every time we change a transformation, we can use a transform dialect interpreter pass to apply this transformation sequence to the payload IR. As we will see in the next chapter, it is possible to define custom passes or even integrate the transform interpreter into a larger pass.
For now, we can use the existing test pass: + + +```sh +$ mlir-opt matmul.mlir --pass-pipeline=" + builtin.module(test-transform-dialect-interpreter{ + bind-first-extra-to-ops=linalg.matmul + bind-second-extra-to-ops=linalg.elemwise_binary})" +``` + +The `matmul.mlir` file contains _both_ the payload IR function _and_ the transform IR sequence nested in the same module. The transform interpreter will find the first top-level transform operation in the root operation of the pass (the module in our case) and apply it to that root operation. Here, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all `linalg.matmul` and `linalg.elemwise_binary` payload operations through the respective pass options. Running this pass results in the expected remarks: + +```sh +matmul.mlir:7:13: remark: matmul + %matmul = linalg.matmul + ^ +matmul.mlir:7:13: note: see current operation: %0 = linalg.matmul ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> +matmul.mlir:10:13: remark: elemwise_binaries + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ^ +matmul.mlir:10:13: note: see current operation: %1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> +matmul.mlir:14:13: remark: elemwise_binaries + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ^ +matmul.mlir:14:13: note: see current operation: %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>} ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> +``` + +Note that `%arg2` is associated with both elementwise payload operations. Any handle is associated with a list of entities. Individual transformations may or may not care about the order of elements in that list.
+ + +## Specifying Transformations + +Now that we have handles to the operations we want to transform, we are ready to apply the transformations. Let us first try tiling the matmul operation itself. + +```mlir +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + // The actual tiling transformation takes tile sizes as attributes. + %loop, %tiled = transform.structured.tile_to_forall_op %arg1 tile_sizes [4, 32] + : (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op) + transform.yield +} +``` + +The transformation returns two handles, as indicated in its [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformstructuredtile_to_forall_op-mlirtransformtiletoforallop): + +* A handle to the `scf.forall` “multi-for” loop around tensors. +* A handle to the tiled operation (here, a `linalg.matmul`) operating on a subset of the original data. + +As expected, running this transformation with the same command as above produces the tiled code.
+ +```mlir +func.func @fc_relu(%arg0: tensor<512x512xf32>, %arg1: tensor<512x512xf32>, %arg2: tensor<512x512xf32>, %arg3: tensor<512x512xf32>) -> tensor<512x512xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = scf.forall (%arg4, %arg5) in (128, 16) shared_outs(%arg6 = %arg3) -> (tensor<512x512xf32>) { + %3 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg4) + %4 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg5) + %extracted_slice = tensor.extract_slice %arg0[%3, 0] [4, 512] [1, 1] + : tensor<512x512xf32> to tensor<4x512xf32> + %extracted_slice_0 = tensor.extract_slice %arg1[0, %4] [512, 32] [1, 1] + : tensor<512x512xf32> to tensor<512x32xf32> + %extracted_slice_1 = tensor.extract_slice %arg6[%3, %4] [4, 32] [1, 1] + : tensor<512x512xf32> to tensor<4x32xf32> + %5 = linalg.matmul + ins(%extracted_slice, %extracted_slice_0 + : tensor<4x512xf32>, tensor<512x32xf32>) + outs(%extracted_slice_1 : tensor<4x32xf32>) -> tensor<4x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %5 into %arg6[%3, %4] [4, 32] [1, 1] + : tensor<4x32xf32> into tensor<512x512xf32> + } + } + %1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} + ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> + %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>} + ins(%1, %cst : tensor<512x512xf32>, f32) + outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> + return %2 : tensor<512x512xf32> +} +``` + +Besides producing new handles, the tiling transform operation _consumes_ the operand handle. This means that the handle is _invalidated_ after this operation and must no longer be used. Transform operations are required to mark all their operands as either consumed or readonly. Transform operations usually consume the operand if the associated payload operations are erased or recreated (which means erased and created anew with similar structure).
As handles are essentially references to payload operations, they would become dangling once the payload no longer exists. + + +## Handle Invalidation and Expensive Checks Mode + +Undefined behavior is difficult to grapple with when it does happen, so the transform dialect interpreter provides a set of additional expensive checks that detect most undefined behavior in the transform IR. For example, if we wanted to use the `%arg1` handle after it is consumed, it would cause undefined behavior that manifests as an assertion in a debug build, and likely as a segmentation fault in release mode. + +```mlir +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + // The actual tiling transformation takes tile sizes as attributes. + %loop, %tiled = transform.structured.tile_to_forall_op %arg1 tile_sizes [4, 32] + : (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op) + + // This is trying to use an invalidated handle leading to undefined behavior.
+  transform.test_print_remark_at_operand %arg1, "remark" : !transform.op<"linalg.matmul">
+  transform.yield
+}
+```
+
+However, with the expensive checks enabled in the interpreter, a nice diagnostic is produced:
+
+```sh
+$ mlir-opt matmul.mlir --pass-pipeline="
+    builtin.module(test-transform-dialect-interpreter{
+        bind-first-extra-to-ops=linalg.matmul
+        bind-second-extra-to-ops=linalg.elemwise_binary
+        enable-expensive-checks})"
+```
+
+```sh
+matmul.mlir:28:3: error: op uses a handle invalidated by a previously executed transform op
+  transform.test_print_remark_at_operand %mm, "elemwise_binaries" : !transform.any_op
+  ^
+matmul.mlir:26:9: note: handle to invalidated ops
+  %mm = transform.cast %matmul : !transform.op<"linalg.matmul"> to !transform.any_op
+        ^
+matmul.mlir:27:19: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
+  %loop, %tiled = transform.structured.tile_to_forall_op %mm tile_sizes [4, 32]
+```
+
+One may observe that some operations such as `transform.cast` do not consume the operand (because they don’t erase the corresponding operation). So what would happen if we used the handle produced by such an operation after its operand has been consumed?
+
+```mlir
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op,
+     %arg1: !transform.op<"linalg.matmul">,
+     %arg2: !transform.op<"linalg.elemwise_binary">):
+  // We can cast one type to another as long as operations are compatible
+  // with both types. This creates "aliasing" handles.
+  %casted = transform.cast %arg1 : !transform.op<"linalg.matmul">
+      to !transform.any_op
+
+  // The actual tiling transformation takes tile sizes as attributes.
+  %loop, %tiled = transform.structured.tile_to_forall_op %arg1 tile_sizes [4, 32]
+    : (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
+
+  // Consuming an operand invalidates the consumed handle and any other handle that is
+  // associated with the same payload operations, or payload operations nested in them.
+  transform.test_print_remark_at_operand %casted, "remark"
+    : !transform.any_op
+  transform.yield
+}
+```
+
+Both `%arg1` and `%casted` reference the same payload operation. Extending the reference analogy, these references alias. Naturally, when the payload operation is erased, all references to it become dangling. The same is true for handles. In fact, consuming an operand invalidates the operand handle as well as any other handle that is associated with any of the same payload operations. This rule also applies recursively in the payload IR: a handle associated with a payload operation _nested_ in the erased one is also invalidated (because erasing the operation also erases its regions and all contained operations). The expensive-checks mode detects this case as well:
+
+```sh
+matmul.mlir:28:3: error: op uses a handle invalidated by a previously executed transform op
+  transform.test_print_remark_at_operand %matmul, "elemwise_binaries" : !transform.op<"linalg.matmul">
+  ^
+matmul.mlir:21:29: note: handle to invalidated ops
+^bb0(%root: !transform.any_op, %matmul: !transform.op<"linalg.matmul">, %elemwise: !transform.op<"linalg.elemwise_binary">):
+                            ^
+matmul.mlir:27:19: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
+  %loop, %tiled = transform.structured.tile_to_forall_op %mm tile_sizes [4, 32]
+```
+
+## Chaining Transformations with Handles
+
+Going back to the transformation sequence, we have tiled the matrix multiplication, but we also want to tile and fuse the elementwise operations.
+The typical way of doing so in the structured operations paradigm is to tile the last operation in some acyclic dataflow graph, and then progressively fuse the operations that produce its operands. This removes the need to explicitly tile all operations, as fusion can adapt their sizes and inject recomputation if desired. So instead of tiling the matmul operation, we are going to tile the last operation in the chain, and then fuse the preceding operations into the loops produced by tiling.
+
+```mlir
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op,
+     %arg1: !transform.op<"linalg.matmul">,
+     %arg2: !transform.op<"linalg.elemwise_binary">):
+  // Since the %arg2 handle is associated with both elementwise operations,
+  // we need to split it into two handles so we can target only the second
+  // elementwise operation.
+  %add, %max = transform.split_handle %arg2
+      : (!transform.op<"linalg.elemwise_binary">)
+      -> (!transform.any_op, !transform.any_op)
+
+  // The actual tiling transformation takes tile sizes as attributes. It produces a
+  // handle to the loop generated during tiling.
+  %loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // We can now fuse the other operations into the loop. Here, we fuse
+  // operations one-by-one. This requires the operation that is being fused
+  // to define the value used within the loop, so the order of such fusions
+  // is important. We could also use "transform.merge_handles" to obtain
+  // a single handle to all operations and give it to `fuse_into_containing_op`
+  // that would take care of the ordering in this case.
+  %add_fused = transform.structured.fuse_into_containing_op %add into %loop
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+  %matmul_fused = transform.structured.fuse_into_containing_op %arg1 into %loop
+      : (!transform.op<"linalg.matmul">, !transform.any_op) -> !transform.any_op
+
+  transform.yield
+}
+```
+
+This achieves the desired tiling and fusion.
+
+## More Handle Invalidation
+
+Finally, let us assume there exists an efficient microkernel, or a hardware instruction expressed as an intrinsic function, for a 4x4 matrix multiplication. To use it, we need to tile the fused operation to the desired size, and then outline it. The resulting function call can then be replaced with a call to the microkernel.
+
+```mlir
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op,
+     %arg1: !transform.op<"linalg.matmul">,
+     %arg2: !transform.op<"linalg.elemwise_binary">):
+  // Since the %arg2 handle is associated with both elementwise operations,
+  // we need to split it into two handles so we can target only the second
+  // elementwise operation.
+  %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
+      -> (!transform.any_op, !transform.any_op)
+
+  // The actual tiling transformation takes tile sizes as attributes. It produces a
+  // handle to the loop generated during tiling.
+  %loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // We can now fuse the other operations into the loop. Here, we fuse
+  // operations one-by-one. This requires the operation that is being fused
+  // to define the value used within the loop, so the order of such fusions
+  // is important. We could also use "transform.merge_handles" to obtain
+  // a single handle to all operations and give it to `fuse_into_containing_op`
+  // that would take care of the ordering in this case.
+  %add_fused = transform.structured.fuse_into_containing_op %add into %loop
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+  %matmul_fused = transform.structured.fuse_into_containing_op %arg1 into %loop
+      : (!transform.op<"linalg.matmul">, !transform.any_op) -> !transform.any_op
+
+  // Tile again to get the desired size. Note that this time we tile the
+  // "add" operation and fuse matmul into the loop, but do not affect the
+  // "max" operation. This illustrates the precise targeting with the transform
+  // dialect. Without it, it would be difficult to differentiate "add" and "max",
+  // both of which are the same kind of operation.
+  %loop_2, %tiled_2 = transform.structured.tile_to_forall_op %add_fused tile_sizes [4, 4]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+
+  // Since outlining is currently only implemented for region-holding operations
+  // such as loops, use tiling to size 1 to materialize the outer loop that is
+  // going to be outlined.
+  %outline_target, %_ = transform.structured.tile_to_forall_op %tiled_2 tile_sizes [1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+  %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
+      : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)
+
+  transform.yield
+}
+```
+
+This additional transformation also illustrates handle invalidation for nested operations. The `transform.loop.outline` operation consumes the handle to the loop, which invalidates it and all handles to any operations nested in it, such as `%_`. Attempting to use these handles will cause undefined behavior.
+(Note that it isn’t strictly necessary for this specific form of outlining to consume the operand, as the implementation only _moves_ the region without recreating the operations, but the author of the transformation chose to invalidate the handle anyway.)
+
+Attempting to use the `%outline_target` handle after outlining produces the following error:
+
+```sh
+test/Examples/transform/Ch1/invalidation-2.mlir:109:3: error: op uses a handle invalidated by a previously executed transform op
+  transform.test_print_remark_at_operand %outline_target, "outlined loop" : !transform.any_op
+  ^
+test/Examples/transform/Ch1/invalidation-2.mlir:102:25: note: handle to invalidated ops
+  %outline_target, %_ = transform.structured.tile_to_forall_op %tiled_2 tile_sizes [1]
+                        ^
+test/Examples/transform/Ch1/invalidation-2.mlir:106:18: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
+  %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
+                 ^
+test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: ancestor payload op
+  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+            ^
+test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: nested payload op
+  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
+```
+
+Note that the “add” elementwise operation is reported as the ancestor payload op because it was used to produce the tiled loop, so the loop carries its location.
+
+Finally, we would like to replace the call to the outlined function with a call to the microkernel. Unfortunately, the Transform dialect doesn’t support this transformation (and cannot support it if the call is rewritten to a custom, out-of-tree operation). Therefore, we need to define new transform operations. The next chapters will describe how this can be done.
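The invalidation rules described in this chapter can be summarized in executable form with a toy C++ model that does not use MLIR at all. It assumes that payload operations can be represented as integers with a parent map that records nesting, and that handles can be represented as named lists of those integers. `ToyPayload` and `ToyHandleTable` are hypothetical names used only for this sketch; they are not MLIR classes, and the real interpreter tracks considerably more state.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Toy payload IR (not MLIR): operations are integers, and `parent` maps an op
// to the op whose region contains it. Top-level ops have no entry.
struct ToyPayload {
  std::map<int, int> parent;

  // True if `op` is `ancestor` itself or is transitively nested in it.
  bool isNestedIn(int op, int ancestor) const {
    int cur = op;
    while (true) {
      if (cur == ancestor)
        return true;
      auto it = parent.find(cur);
      if (it == parent.end())
        return false;
      cur = it->second;
    }
  }
};

// Toy association between named transform IR handles and payload ops.
class ToyHandleTable {
public:
  explicit ToyHandleTable(const ToyPayload &payload) : payload_(payload) {}

  void define(const std::string &handle, std::vector<int> ops) {
    handles_[handle] = std::move(ops);
    invalidated_.erase(handle);
  }

  bool isValid(const std::string &handle) const {
    return handles_.count(handle) != 0 && invalidated_.count(handle) == 0;
  }

  // Consuming a handle invalidates it, every other handle associated with any
  // of the same payload ops (aliases), and every handle to ops nested in them.
  void consume(const std::string &handle) {
    assert(isValid(handle) && "consuming an already invalidated handle");
    const std::vector<int> &consumed = handles_.at(handle);
    for (const auto &entry : handles_) {
      for (int op : entry.second) {
        bool affected = false;
        for (int erased : consumed)
          affected = affected || payload_.isNestedIn(op, erased);
        if (affected) {
          invalidated_.insert(entry.first);
          break;
        }
      }
    }
    invalidated_.insert(handle);
  }

private:
  const ToyPayload &payload_;
  std::map<std::string, std::vector<int>> handles_;
  std::set<std::string> invalidated_;
};
```

Consider a loop `1` that contains the tiled matmul `2`, plus a separate "max" operation `3`. Consuming the loop handle invalidates the aliasing `casted` handle and the `tiled` handle to the nested operation. The handle to `3` remains valid. This mirrors the `%casted` and `transform.loop.outline` cases above.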
diff --git a/mlir/docs/Tutorials/transform/Ch2.md b/mlir/docs/Tutorials/transform/Ch2.md
new file mode 100644
--- /dev/null
+++ b/mlir/docs/Tutorials/transform/Ch2.md
@@ -0,0 +1,327 @@
+# Chapter 2: Adding a Simple New Transformation Operation
+
+## Setting Up to Add New Transformations
+
+Before defining a new transform operation, we need to choose where its implementation should be located. While MLIR encourages upstream contributions, it is not always possible or even desirable to modify the main Transform dialect, for example, if the transformation is specific to some out-of-tree dialect that is not itself available upstream.
+
+The Transform dialect uses the dialect extension mechanism to allow additional operations to be injected without modifying the dialect itself. Dialect extensions are registered with the context and loaded when the dialect itself is loaded. Extension definition is straightforward:
+
+```cpp
+// In MyExtension.cpp.
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+// Define a new transform dialect extension. This uses the CRTP idiom to identify
+// extensions.
+class MyExtension
+    : public ::mlir::transform::TransformDialectExtension<MyExtension> {
+public:
+  // The extension must inherit the base constructor.
+  using Base::Base;
+
+  // This function initializes the extension, similarly to `initialize` in dialect
+  // definitions. List individual operations and dependent dialects here.
+  void init();
+};
+
+void MyExtension::init() {
+  // Similarly to dialects, an extension can declare a dependent dialect. This dialect
+  // will be loaded along with the extension and, therefore, along with the Transform
+  // dialect. Only declare as dependent the dialects that contain the attributes or
+  // types used by transform operations. Do NOT declare as dependent the dialects
+  // produced during the transformation.
+  // declareDependentDialect();
+
+  // When transformations are applied, they may produce new operations from previously
+  // unloaded dialects. Typically, a pass would need to declare itself dependent on
+  // the dialects containing such new operations. To avoid confusion with the dialects
+  // the extension itself depends on, the Transform dialect differentiates between:
+  //   - dependent dialects, which are used by the transform operations, and
+  //   - generated dialects, which contain the entities (attributes, operations,
+  //     types) that may be produced by applying the transformation even when not
+  //     present in the original payload IR.
+  // In the following chapter, we will add operations that generate function calls
+  // and structured control flow operations, so let's declare the corresponding
+  // dialects as generated.
+  declareGeneratedDialect<::mlir::scf::SCFDialect>();
+  declareGeneratedDialect<::mlir::func::FuncDialect>();
+
+  // Finally, we register the additional transform operations with the dialect.
+  registerTransformOps<
+    // TODO: list the operation classes.
+  >();
+}
+```
+
+The operations themselves can be defined using ODS, exactly in the same way as regular operations in a dialect.
+
+```tablegen
+// In MyExtension.td
+#ifndef MY_EXTENSION
+#define MY_EXTENSION
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def MyOp : Op<Transform_Dialect, "my.op", [
+    // TODO: interfaces and traits here.
+  ]> {
+  let summary = "my transform op";
+  // TODO: define the operation properties.
+}
+
+#endif // MY_EXTENSION
+```
+
+Similarly to dialects, we must use Tablegen to generate the header and implementation of these operations. We can instruct CMake to do it as follows.
+
+```cmake
+# In CMakeLists.txt next to MyExtension.td.
+
+# Tell Tablegen to use MyExtension.td as input.
+set(LLVM_TARGET_DEFINITIONS MyExtension.td)
+
+# Ask Tablegen to generate op declarations and definitions from ODS.
+mlir_tablegen(MyExtension.h.inc -gen-op-decls)
+mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)
+
+# Add a CMake target we can depend on to ensure the generation happens before the compilation.
+add_public_tablegen_target(MyExtensionIncGen)
+
+# Don't forget to generate the documentation; this will produce a MyExtension.md under
+# Dialects.
+add_mlir_doc(MyExtension MyExtension Dialects/ -gen-op-doc)
+```
+
+```cmake
+# In CMakeLists.txt next to MyExtension.cpp
+add_mlir_library(
+  # Library called MyExtension.
+  MyExtension
+
+  # Built from the following source files.
+  MyExtension.cpp
+
+  # Make sure ODS declaration and definitions are generated before compiling this.
+  DEPENDS
+  MyExtensionIncGen
+
+  # Link in the transform dialect, and all generated dialects.
+  LINK_LIBS PUBLIC
+  MLIRTransformDialect
+  MLIRFuncDialect
+  MLIRSCFDialect
+)
+```
+
+This will generate two files, `MyExtension.h.inc` and `MyExtension.cpp.inc`, which are meant to be included in the declaration and definition of the transform operations, respectively.
+
+```c++
+// In MyExtension.h.
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "MyExtension.h.inc"
+```
+
+```c++
+// In MyExtension.cpp.
+#include "MyExtension.h"
+
+#define GET_OP_CLASSES
+#include "MyExtension.cpp.inc"
+
+// …
+void MyExtension::init() {
+  // …
+
+  // Finally, we register the additional transform operations with the dialect. List all
+  // operations generated from ODS. This call will perform additional checks that the
+  // operations implement the transform and memory effect interfaces required by the
+  // dialect interpreter and assert if they do not.
+  registerTransformOps<
+#define GET_OP_LIST
+#include "MyExtension.cpp.inc"
+  >();
+}
+```
+
+## Defining a Transform Operation
+
+With this setup, we are now ready to define the new transform operation to rewrite the function call. This is identical to defining a regular operation in a dialect. Note that the Transform dialect requires operations to implement the `TransformOpInterface` as well as `MemoryEffectsOpInterface` to indicate whether the operands are consumed or only read. Our operation can be defined along the following lines.
+
+```tablegen
+// In MyExtension.td.
+
+// Define the new operation. By convention, prefix its name with the name of the dialect
+// extension, "my.". The full operation name will be further prefixed with "transform.".
+def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
+    // Indicate that the operation implements the required TransformOpInterface and
+    // MemoryEffectsOpInterface.
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  // Provide a brief and a full description. It is recommended that the latter describes
+  // the effects on the operands and how the operation processes various failure modes.
+  let summary = "Changes the callee of a call operation to the specified one";
+  let description = [{
+    For each `func.call` payload operation associated with the handle, changes its
+    callee to be the symbol whose name is provided as an attribute to this operation.
+
+    Generates a silenceable failure if the operand is associated with payload operations
+    that are not `func.call`.
+    Only reads the operand.
+  }];
+
+  // The arguments include the handle to the payload operations and the attribute that
+  // specifies the new callee. The handle must implement TransformHandleTypeInterface.
+  // We use a string attribute as the symbol may not exist in the transform IR so the
+  // verification may fail.
+  let arguments = (ins
+    TransformHandleTypeInterface:$call,
+    StrAttr:$new_target);
+
+  // The results are empty as the transformation does not produce any new payload.
+  let results = (outs);
+
+  // Provide nice syntax.
+  let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)";
+}
+```
+
+To finalize the definition of the transform operation, we need to implement the interface methods. The `TransformOpInterface` currently requires only one method, `apply`, which performs the actual transformation. It is good practice to limit the body of the method to manipulation of the Transform dialect constructs and have the actual transformation implemented as a standalone function so it can be used from other places in the code.
+
+```c++
+// In MyExtension.cpp
+
+// Implementation of our transform dialect operation.
+// This operation returns a tri-state result that can be one of:
+// - success when the transformation succeeded;
+// - definite failure when the transformation failed in such a way that following
+//   transformations are impossible or undesirable, typically it could have left payload
+//   IR in an invalid state; it is expected that a diagnostic is emitted immediately
+//   before returning the definite error;
+// - silenceable failure when the transformation failed but following transformations
+//   are still applicable, typically this means a precondition for the transformation is
+//   not satisfied and the payload IR has not been modified.
+// The silenceable failure additionally carries a Diagnostic that can be emitted to the
+// user.
+::mlir::DiagnosedSilenceableFailure ChangeCallTargetOp::apply(
+    // The list of payload IR entities that will be associated with the transform IR
+    // values defined by this transform operation. In this case, it can remain empty as
+    // there are no results.
+    ::mlir::transform::TransformResults &results,
+    // The transform application state. This object can be used to query the current
+    // associations between transform IR values and payload IR entities. It can also
+    // carry additional user-defined state.
+    ::mlir::transform::TransformState &state) {
+
+  // First, we need to obtain the list of payload operations that are associated with
+  // the operand handle.
+  auto payload = state.getPayloadOps(getCall());
+
+  // Then, we iterate over the list of payload operations and call the actual
+  // IR-mutating function. We also check the preconditions here.
+  for (Operation *payloadOp : payload) {
+    auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
+    if (!call) {
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+          << "only applies to func.call payloads";
+      diag.attachNote(payloadOp->getLoc()) << "offending payload";
+      return diag;
+    }
+
+    // `updateCallee` stands for the standalone transformation function recommended
+    // above; its implementation is not shown here.
+    updateCallee(call, getNewTarget());
+  }
+
+  // If everything went well, return success.
+  return DiagnosedSilenceableFailure::success();
+}
+```
+
+The implementation of the `MemoryEffectsOpInterface` must specify the effects this operation has on its operands (consumed or readonly) and on the payload IR (mutates or readonly). Transform dialect verifiers will check for side effects being present and assert in debug builds if they are not.
+
+```c++
+// In MyExtension.cpp
+
+void ChangeCallTargetOp::getEffects(
+    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
+  // Indicate that the `call` handle is only read by this operation because the
+  // associated operation is not erased but rather modified in-place, so the
+  // reference to it remains valid.
+  onlyReadsHandle(getCall(), effects);
+
+  // Indicate that the payload is modified by this operation.
+  modifiesPayload(effects);
+}
+```
+
+## Registration and Usage
+
+This is enough to define transform operations. The only remaining bit is providing the extension registration hook that can be called from the project’s `main`.
+
+```c++
+// In MyExtension.cpp (don't forget a declaration in MyExtension.h).
+
+void registerMyExtension(::mlir::DialectRegistry &registry) {
+  registry.addExtensions<MyExtension>();
+}
+```
+
+After registering the extension, it becomes possible to use our new operation in the transform dialect interpreter. The upstream testing pass can be used as is.
+
+```mlir
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op,
+     %arg1: !transform.op<"linalg.matmul">,
+     %arg2: !transform.op<"linalg.elemwise_binary">):
+  // Since the %arg2 handle is associated with both elementwise operations,
+  // we need to split it into two handles so we can target only the second
+  // elementwise operation.
+  %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
+      -> (!transform.any_op, !transform.any_op)
+
+  // The actual tiling transformation takes tile sizes as attributes. It produces a
+  // handle to the loop generated during tiling.
+  %loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // We can now fuse the other operations into the loop. Here, we fuse
+  // operations one-by-one. This requires the operation that is being fused
+  // to define the value used within the loop, so the order of such fusions
+  // is important. We could also use "transform.merge_handles" to obtain
+  // a single handle to all operations and give it to `fuse_into_containing_op`
+  // that would take care of the ordering in this case.
+  %add_fused = transform.structured.fuse_into_containing_op %add into %loop
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+  %matmul_fused = transform.structured.fuse_into_containing_op %arg1 into %loop
+      : (!transform.op<"linalg.matmul">, !transform.any_op) -> !transform.any_op
+
+  // Tile again to get the desired size.
+  // Note that this time we tile the "add" operation and fuse matmul into the
+  // loop, but do not affect the "max" operation. This illustrates the precise
+  // targeting with the transform dialect. Without it, it would be difficult to
+  // differentiate "add" and "max", both of which are the same kind of operation.
+  %loop_2, %tiled_2 = transform.structured.tile_to_forall_op %add_fused tile_sizes [4, 4]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+
+  // Since outlining is currently only implemented for region-holding operations
+  // such as loops, use tiling to size 1 to materialize the outer loop that is
+  // going to be outlined.
+  %outline_target, %_ = transform.structured.tile_to_forall_op %tiled_2 tile_sizes [1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+  %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+  // Rewrite the call target.
+  transform.my.change_call_target %call, "microkernel" : !transform.any_op
+
+  transform.yield
+}
+```
diff --git a/mlir/docs/Tutorials/transform/Ch3.md b/mlir/docs/Tutorials/transform/Ch3.md
new file mode 100644
--- /dev/null
+++ b/mlir/docs/Tutorials/transform/Ch3.md
@@ -0,0 +1,283 @@
+# Chapter 3: More than Simple Transform Operations
+
+## Type Constraints and ApplyEach Trait
+
+A transform operation that applies to each payload operation individually and requires it to be of a specific kind is a recurring pattern. One can use Transform dialect types to express such preconditions on the payload operations.
+Specifically, we can change the expected operand type from the broad `TransformHandleTypeInterface` to the narrower `Transform_ConcreteOpType<"func.call">`. Furthermore, we use the `TransformEachOpTrait` trait to provide the skeleton implementation of the `apply` method that performs verification, iteration over payloads and result concatenation. The improved ODS op definition is as follows.
+
+```tablegen
+// In MyExtension.td.
+
+// Define the new operation. By convention, prefix its name with the name of the dialect
+// extension, "my.". The full operation name will be further prefixed with "transform.".
+def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
+    // Indicate that the operation implements the required TransformOpInterface and
+    // MemoryEffectsOpInterface. Use the TransformEachOpTrait to provide the
+    // implementation for TransformOpInterface.
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  // Provide a brief and a full description. It is recommended that the latter describes
+  // the effects on the operands and how the operation processes various failure modes.
+  let summary = "Changes the callee of a call operation to the specified one";
+  let description = [{
+    For each `func.call` payload operation associated with the handle, changes its
+    callee to be the symbol whose name is provided as an attribute to this operation.
+
+    Generates a silenceable failure if the operand is associated with payload operations
+    that are not `func.call`.
+    Only reads the operand.
+  }];
+
+  // The arguments include the handle to the payload operations and the attribute that
+  // specifies the new callee. The handle must implement TransformHandleTypeInterface.
+  // We use a string attribute as the symbol may not exist in the transform IR so the
+  // verification may fail.
+  let arguments = (ins
+    Transform_ConcreteOpType<"func.call">:$call,
+    StrAttr:$new_target);
+
+  // The results are empty as the transformation does not produce any new payload.
+  let results = (outs);
+
+  // Provide nice syntax.
+  let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)";
+
+  // Declare the function implementing the interface for a single payload operation.
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::func::CallOp call,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+```
+
+Now, instead of defining the `apply` method with a loop, we can simply define a function that applies to an individual payload operation, and the trait will take care of the rest.
+
+```c++
+::mlir::DiagnosedSilenceableFailure ChangeCallTargetOp::applyToOne(
+    ::mlir::func::CallOp call,
+    ::mlir::transform::ApplyToEachResultList &results,
+    ::mlir::transform::TransformState &state) {
+  // Call the actual transformation function.
+  updateCallee(call, getNewTarget());
+  // Indicate success.
+  return DiagnosedSilenceableFailure::success();
+}
+```
+
+## Defining a Transform Type
+
+In addition to operations, the Transform dialect allows extensions to define and inject additional attributes and types. As we have seen above, transform types are used to specify constraints on the payload operations. Our call rewriting operation currently applies only to `func.call`. We may want to generalize it to apply to any payload operation that implements `CallOpInterface`, but the Transform dialect currently doesn’t provide a type that checks if a payload operation implements this interface. Let’s define one in our extension.
+
+Type definition is again identical to defining a dialect type with ODS.
+
+```tablegen
+// Transform dialect allows additional types to be defined and injected.
+def CallOpInterfaceHandle
+  : TypeDef<Transform_Dialect, "CallOpInterfaceHandle",
+      // The type must implement `TransformHandleTypeInterface`.
+      [DeclareTypeInterfaceMethods<TransformHandleTypeInterface>]> {
+
+  // The usual components of a type such as description, mnemonic and assembly format
+  // should be provided.
+  let summary = "handle to payload operations implementing CallOpInterface";
+  let mnemonic = "my.call_op_interface";
+  let assemblyFormat = "";
+}
+```
+
+We will omit the generation of declaration and definitions using Tablegen for brevity as it is identical to the regular case.
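Returning to `TransformEachOpTrait`, the split of work between the trait and `applyToOne` can be sketched with another toy, MLIR-free C++ model. The model assumes a single result handle and identifies payload operations by name only. `ToyOp`, `ToyResult`, and `applyToEach` are hypothetical names. Two simplifications are deliberate. First, the type precondition is checked at the start of the driver, whereas in MLIR a handle type checks payload operations as they are associated with the handle. Second, the driver stops at the first failure of either kind, which may differ from how the actual trait treats silenceable failures.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy stand-ins; these are NOT MLIR types.
enum class ToyResult { Success, SilenceableFailure, DefiniteFailure };

struct ToyOp {
  std::string name; // Operation name, e.g. "func.call".
};

// Per-payload callback in the spirit of `applyToOne`: on success it must
// append exactly one entity for the (single) result handle.
using ToyApplyToOne =
    std::function<ToyResult(const ToyOp &, std::vector<ToyOp> &)>;

// Toy "apply to each" driver: check the precondition on all payload ops,
// apply the per-op callback to each one, and concatenate the results.
ToyResult applyToEach(const std::vector<ToyOp> &payload,
                      const std::string &requiredName,
                      const ToyApplyToOne &applyToOne,
                      std::vector<ToyOp> &resultHandle) {
  // Precondition check before any mutation: failing here leaves the payload
  // untouched, so the failure is silenceable.
  for (const ToyOp &op : payload)
    if (op.name != requiredName)
      return ToyResult::SilenceableFailure;

  for (const ToyOp &op : payload) {
    std::vector<ToyOp> partial;
    ToyResult result = applyToOne(op, partial);
    // Simplification: stop at the first failure of either kind.
    if (result != ToyResult::Success)
      return result;
    assert(partial.size() == 1 && "one entity per result per application");
    resultHandle.push_back(partial.front());
  }
  return ToyResult::Success;
}
```

Rewriting two `func.call` payloads produces a two-element result list, one entry per application. A payload list that contains a non-call operation fails silenceably before any callback runs.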
+
+To finalize the definition of a transform type, one must implement the interface methods.
+
+```c++
+// In MyExtension.cpp.
+
+// The interface declares this method to verify constraints this type has on
+// payload operations. It returns the now familiar tri-state result.
+mlir::DiagnosedSilenceableFailure
+mlir::transform::CallOpInterfaceHandleType::checkPayload(
+    // Location at which diagnostics should be emitted.
+    mlir::Location loc,
+    // List of payload operations that are about to be associated with the
+    // handle that has this type.
+    llvm::ArrayRef<mlir::Operation *> payload) const {
+
+  // All payload operations are expected to implement CallOpInterface, check this.
+  for (Operation *op : payload) {
+    if (llvm::isa<mlir::CallOpInterface>(op))
+      continue;
+
+    // By convention, these verifiers always emit a silenceable failure since they are
+    // checking a precondition.
+    DiagnosedSilenceableFailure diag = emitSilenceableError(loc)
+        << "expected the payload operation to implement CallOpInterface";
+    diag.attachNote(op->getLoc()) << "offending operation";
+    return diag;
+  }
+
+  // If everything is okay, return success.
+  return DiagnosedSilenceableFailure::success();
+}
+```
+
+Additional attributes and types need to be registered in the extension, next to operations.
+
+```c++
+// In MyExtension.cpp.
+
+void MyExtension::init() {
+  // …
+
+  registerTypes<
+#define GET_TYPEDEF_LIST
+#include "MyExtensionTypes.cpp.inc"
+  >();
+}
+```
+
+This type is now directly available in the transform dialect and can be used in operations.
+
+```mlir
+  // Cast to our new type.
+  %casted = transform.cast %call : !transform.any_op to !transform.my.call_op_interface
+  // Using our new operation.
+  transform.my.change_call_target %casted, "microkernel" : !transform.my.call_op_interface
+```
+
+## Operand Consumption
+
+As an exercise, let us modify the rewriting operation to consume the operand.
+This would be necessary, for example, if the transformation were to rewrite the `func.call` operation into a custom operation `my.mm4`. Since the operand handle is now consumed, the operation can return a new handle to the newly produced payload operation. Otherwise, the ODS definition of the transform operation follows the same structure as before.
+
+```tablegen
+// In MyExtension.td.
+
+// Define another transform operation.
+def CallToOp : Op<Transform_Dialect, "my.call_to_op",
+    // Indicate that the operation implements the required TransformOpInterface and
+    // MemoryEffectsOpInterface. Use the TransformEachOpTrait to provide the
+    // implementation for TransformOpInterface.
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  // Summary and description omitted for brevity.
+
+  // The argument is the handle to the payload operations.
+  let arguments = (ins CallOpInterfaceHandle:$call);
+
+  // The result is the handle to the payload operations produced during the
+  // transformation.
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  // Provide nice syntax.
+  let assemblyFormat = "$call attr-dict `:` functional-type(operands, results)";
+
+  // Declare the function implementing the interface for a single payload operation.
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::CallOpInterface call,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+```
+
+Now let’s look at the implementation of the interface methods.
+
+```c++
+// In MyExtension.cpp.
+
+::mlir::DiagnosedSilenceableFailure CallToOp::applyToOne(
+    ::mlir::CallOpInterface call,
+    ::mlir::transform::ApplyToEachResultList &results,
+    ::mlir::transform::TransformState &state) {
+  // Call the actual rewrite.
+  Operation *rewritten = rewriteToOp(call);
+
+  // Report an error if the rewriter produced a null pointer. Note that it may have
+  // irreversibly modified the payload IR, so we produce a definite failure.
+  if (!rewritten) {
+    return emitDefiniteError() << "failed to rewrite call to operation";
+  }
+
+  // On success, push the resulting operation into the result list. The list is expected
+  // to contain exactly one entity per result and per application.
The handles will be + // associated with lists of the respective values produced by each application. + results.push_back(rewritten); + + // If everything is fine, return success. + return DiagnosedSilenceableFailure::success(); +} + +void CallToOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + // Indicate using side effects that the operand handle is consumed, and the + // result handle is produced. + consumesHandle(getCall(), effects); + producesHandle(getTransformed(), effects); + + // Indicate that the payload IR is modified. + modifiesPayload(effects); +} +``` + +The overall flow of these implementations is similar to the previous one. The application also needs to specify the resulting entities that are going to be associated with the handles it produces. Operations are required to specify the entities to associate with _all_ results on success, even if the list is empty. An assertion is triggered if this is not the case. In case of failure, the interpreter will automatically associate all results that are not yet defined with empty lists. + +Note that since `applyToOne` always expects one payload entity to be associated with each result handle in each application, it cannot be used to return handles associated with empty lists for non-empty operand handles. Instead, one would use `apply` directly. + +```c++ +::mlir::DiagnosedSilenceableFailure SomeOtherOp::apply( + ::mlir::transform::TransformResults &results, + ::mlir::transform::TransformState &state) { + // ... + + // Associate the result `transformed` with an empty list of payload operations. + results.set(cast<OpResult>(getTransformed()), {}); + return DiagnosedSilenceableFailure::success(); +} +``` + +## Memory Effects Traits + +Some common memory effect patterns are also available as traits to minimize the boilerplate.
+ +* `FunctionalStyleTransformOpTrait` indicates that all handle-typed operands are consumed, all results are produced, and the payload IR is modified. +* `NavigationTransformOpTrait` indicates that all handle-typed operands are only read, all results are produced, and the payload IR is only read. + +Using these traits removes the need to declare or define the methods of the `MemoryEffectsOpInterface`. + +```tablegen +// In MyExtension.td. + +// Define another transform operation. +def CallToOp : Op { + // Summary and description omitted for brevity. + + // The argument is the handle to the payload operations. + let arguments = (ins CallOpInterfaceHandle:$call); + + // The result is the handle to the payload operations produced during the + // transformation. + let results = (outs TransformHandleTypeInterface:$transformed); + + // Provide nice syntax. + let assemblyFormat = "$call attr-dict `:` functional-type(operands, results)"; + + // Declare the function implementing the interface for a single payload operation. + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::CallOpInterface call, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} +``` + + diff --git a/mlir/docs/Tutorials/transform/_index.md b/mlir/docs/Tutorials/transform/_index.md new file mode 100644 --- /dev/null +++ b/mlir/docs/Tutorials/transform/_index.md @@ -0,0 +1,32 @@ +# Transform Dialect Tutorial + +MLIR supports declarative specification for controlling compiler transformations +via the transform dialect. It allows one to request compiler transformations +using compiler IR itself, which can be embedded into the original IR that is +being transformed (similarly to pragmas) or supplied separately (similarly to +scheduling languages). This tutorial presents the concepts of the MLIR transform +dialect and related infrastructure. 
It will be accompanied by a practical +demonstration of three use scenarios: + +- Composing transform dialect operations available in (upstream) MLIR to perform + a sequence of optimizing transformations that results in efficient code for an + MLIR linear algebra operation. +- Defining new transform dialect operations and adapting existing transformation + code to work with the transform dialect infrastructure. +- Setting up and using the transform dialect infrastructure in a downstream + out-of-tree project with custom dialects, transformations and passes. + +After following the tutorial, one will be able to apply the transform dialect in +their work and extend it when necessary. Basic familiarity with MLIR is a +prerequisite. See [Toy tutorial](../Toy) for an introduction to MLIR. + +The tutorial is divided into the following chapters. + +- [Chapter #0](Ch0.md): A Primer on “Structured” Linalg Operations +- [Chapter #1](Ch1.md): Combining Existing Transformations +- [Chapter #2](Ch2.md): Adding a Simple New Transformation Operation +- [Chapter #3](Ch3.md): More than Simple Transform Operations + +The code corresponding to this tutorial is located under +`mlir/examples/transform` and the corresponding tests in +`mlir/test/Examples/transform`.
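The result contract of `applyToOne` described in Chapter 3 (exactly one payload entity per result and per application; the i-th entity of every application is appended to the list of the i-th result handle; results left undefined after a failure become empty lists) can be modeled without any MLIR dependency. The following is only a toy sketch under stated assumptions: `Outcome`, `PayloadOp`, `ResultList` and `applyToEachSketch` are hypothetical names invented for illustration, not MLIR APIs, and failure handling is deliberately simplified (any failure stops the iteration and leaves every result associated with an empty list), whereas the real interpreter may aggregate silenceable failures differently.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for the tri-state DiagnosedSilenceableFailure.
enum class Outcome { Success, SilenceableFailure, DefiniteFailure };

// Hypothetical stand-in for a payload operation.
struct PayloadOp {
  std::string name;
};

using ResultList = std::vector<PayloadOp *>;

// Callback modeled on `applyToOne`: transforms one target and, on success,
// pushes exactly one entity per transform result into `results`.
using ApplyToOneFn = std::function<Outcome(PayloadOp *, ResultList &)>;

// Toy model of applying a transform to each payload op of an operand handle.
// `handles[i]` collects the i-th entity pushed by every successful application.
Outcome applyToEachSketch(const std::vector<PayloadOp *> &targets,
                          std::size_t numResults,
                          const ApplyToOneFn &applyToOne,
                          std::vector<ResultList> &handles) {
  handles.assign(numResults, ResultList());
  for (PayloadOp *target : targets) {
    ResultList partial;
    Outcome outcome = applyToOne(target, partial);
    if (outcome != Outcome::Success) {
      // Simplification (assumption): on any failure, associate every result
      // with an empty list and stop.
      handles.assign(numResults, ResultList());
      return outcome;
    }
    // The contract: exactly one entity per result and per application.
    assert(partial.size() == numResults &&
           "expected one entity per result per application");
    for (std::size_t i = 0; i < numResults; ++i)
      handles[i].push_back(partial[i]);
  }
  return Outcome::Success;
}
```

In this sketch, an application that pushes nothing for a non-empty operand trips the assertion, mirroring the point that `applyToOne` cannot produce empty result lists for non-empty operand handles; an operation needing that behavior would implement `apply` instead.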
diff --git a/mlir/examples/CMakeLists.txt b/mlir/examples/CMakeLists.txt --- a/mlir/examples/CMakeLists.txt +++ b/mlir/examples/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(toy) +add_subdirectory(transform) diff --git a/mlir/examples/transform/CMakeLists.txt b/mlir/examples/transform/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/CMakeLists.txt @@ -0,0 +1,4 @@ +add_custom_target(TransformExample) + +add_subdirectory(Ch2) +add_subdirectory(Ch3) diff --git a/mlir/examples/transform/Ch2/CMakeLists.txt b/mlir/examples/transform/Ch2/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch2/CMakeLists.txt @@ -0,0 +1,20 @@ +# For a better top-level template to copy, see examples/standalone. + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +add_subdirectory(include) +add_subdirectory(lib) + +add_dependencies(TransformExample transform-opt-ch2) +add_llvm_example(transform-opt-ch2 + transform-opt/transform-opt.cpp) + +target_link_libraries(transform-opt-ch2 + PRIVATE + MLIRIR + MLIRMlirOptMain + MLIRSideEffectInterfaces + MyExtensionCh2 +) diff --git a/mlir/examples/transform/Ch2/include/CMakeLists.txt b/mlir/examples/transform/Ch2/include/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch2/include/CMakeLists.txt @@ -0,0 +1,12 @@ +# Tell Tablegen to use MyExtension.td as input. +set(LLVM_TARGET_DEFINITIONS MyExtension.td) + +# Ask Tablegen to generate op declarations and definitions from ODS. +mlir_tablegen(MyExtension.h.inc -gen-op-decls) +mlir_tablegen(MyExtension.cpp.inc -gen-op-defs) + +# Add a CMakeTarget we can depend on to ensure the generation happens before the compilation. +add_public_tablegen_target(MyExtensionCh2IncGen) + +# Don't forget to generate the documentation, this will produce a MyExtension.md under Dialects. 
+add_mlir_doc(MyExtension MyExtensionCh2 Dialects/ -gen-op-doc) diff --git a/mlir/examples/transform/Ch2/include/MyExtension.h b/mlir/examples/transform/Ch2/include/MyExtension.h new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch2/include/MyExtension.h @@ -0,0 +1,21 @@ +//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 2 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" + +#define GET_OP_CLASSES +#include "MyExtension.h.inc" + +// Registers our Transform dialect extension. +void registerMyExtension(::mlir::DialectRegistry &registry); diff --git a/mlir/examples/transform/Ch2/include/MyExtension.td b/mlir/examples/transform/Ch2/include/MyExtension.td new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch2/include/MyExtension.td @@ -0,0 +1,56 @@ +//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 2 of the Transform dialect tutorial.
+// +//===----------------------------------------------------------------------===// + +#ifndef MY_EXTENSION +#define MY_EXTENSION + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Define the new operation. By convention, prefix its name with the name of the dialect +// extension, "my.". The full operation name will be further prefixed with "transform.". +def ChangeCallTargetOp : Op, + DeclareOpInterfaceMethods]> { + // Provide a brief and a full description. It is recommended that the latter describes + // the effects on the operands and how the operation processes various failure modes. + let summary = "Changes the callee of a call operation to the specified one"; + let description = [{ + For each `func.call` payload operation associated with the handle, changes its + callee to be the symbol whose name is provided as an attribute to this operation. + + Generates a silenceable failure if the operand is associated with payload operations + that are not `func.call`. + Only reads the operand. + }]; + + // The arguments include the handle to the payload operations and the attribute that + // specifies the new callee. The handle must implement TransformHandleTypeInterface. + // We use a string attribute as the symbol may not exist in the transform IR so the + // verification may fail. + let arguments = (ins + TransformHandleTypeInterface:$call, + StrAttr:$new_target); + + // The results are empty as the transformation does not produce any new payload. + let results = (outs); + + // Provide nice syntax. 
+ let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)"; +} + +#endif // MY_EXTENSION diff --git a/mlir/examples/transform/Ch2/lib/CMakeLists.txt b/mlir/examples/transform/Ch2/lib/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch2/lib/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_library( + # Library called MyExtensionCh2. + MyExtensionCh2 + + # Built from the following source files. + MyExtension.cpp + + # Make sure ODS declaration and definitions are generated before compiling this. + DEPENDS + MyExtensionCh2IncGen + + # Link in the transform dialect and all generated dialects. + LINK_LIBS PUBLIC + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect +) diff --git a/mlir/examples/transform/Ch2/lib/MyExtension.cpp b/mlir/examples/transform/Ch2/lib/MyExtension.cpp new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch2/lib/MyExtension.cpp @@ -0,0 +1,132 @@ +//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 2 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" + +// Define a new transform dialect extension. This uses the CRTP idiom to +// identify extensions. +class MyExtension + : public ::mlir::transform::TransformDialectExtension<MyExtension> { +public: + // The extension must inherit the base constructor.
+ using Base::Base; + + // This function initializes the extension, similarly to `initialize` in + // dialect definitions. List individual operations and dependent dialects + // here. + void init(); +}; + +void MyExtension::init() { + // Similarly to dialects, an extension can declare a dependent dialect. This + // dialect will be loaded along with the extension and, therefore, along with + // the Transform dialect. Only declare as dependent the dialects that contain + // the attributes or types used by transform operations. Do NOT declare as + // dependent the dialects produced during the transformation. + // declareDependentDialect(); + + // When transformations are applied, they may produce new operations from + // previously unloaded dialects. Typically, a pass would need to declare + // itself dependent on the dialects containing such new operations. To avoid + // confusion with the dialects the extension itself depends on, the Transform + // dialect differentiates between: + // - dependent dialects, which are used by the transform operations, and + // - generated dialects, which contain the entities (attributes, operations, + // types) that may be produced by applying the transformation even when + // not present in the original payload IR. + // In the following chapter, we will add operations that generate function + // calls and structured control flow operations, so let's declare the + // corresponding dialects as generated. + declareGeneratedDialect<::mlir::scf::SCFDialect>(); + declareGeneratedDialect<::mlir::func::FuncDialect>(); + + // Finally, we register the additional transform operations with the dialect. + // List all operations generated from ODS. This call will perform additional + // checks that the operations implement the transform and memory effect + // interfaces required by the dialect interpreter and assert if they do not.
+ registerTransformOps< +#define GET_OP_LIST +#include "MyExtension.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "MyExtension.cpp.inc" + +static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { + call.setCallee(newTarget); +} + +// Implementation of our transform dialect operation. +// This operation returns a tri-state result that can be one of: +// - success when the transformation succeeded; +// - definite failure when the transformation failed in such a way that +//   following transformations are impossible or undesirable; typically, +//   it could have left payload IR in an invalid state. It is expected that +//   a diagnostic is emitted immediately before returning the definite +//   error; +// - silenceable failure when the transformation failed but following +//   transformations are still applicable; typically, this means a +//   precondition for the transformation is not satisfied and the payload +//   IR has not been modified. +// The silenceable failure additionally carries a Diagnostic that can be +// emitted to the user. +::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply( + // The list of payload IR entities that will be associated with the + // transform IR values defined by this transform operation. In this case, it + // can remain empty as there are no results. + ::mlir::transform::TransformResults &results, + // The transform application state. This object can be used to query the + // current associations between transform IR values and payload IR entities. + // It can also carry additional user-defined state. + ::mlir::transform::TransformState &state) { + + // First, we need to obtain the list of payload operations that are associated + // with the operand handle. + auto payload = state.getPayloadOps(getCall()); + + // Then, we iterate over the list of payload operations and call the actual + // IR-mutating function. We also check the preconditions here.
+ for (Operation *payloadOp : payload) { + auto call = dyn_cast<::mlir::func::CallOp>(payloadOp); + if (!call) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "only applies to func.call payloads"; + diag.attachNote(payloadOp->getLoc()) << "offending payload"; + return diag; + } + + updateCallee(call, getNewTarget()); + } + + // If everything went well, return success. + return DiagnosedSilenceableFailure::success(); +} + +void mlir::transform::ChangeCallTargetOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + // Indicate that the `call` handle is only read by this operation because the + // associated operation is not erased but rather modified in-place, so the + // reference to it remains valid. + onlyReadsHandle(getCall(), effects); + + // Indicate that the payload is modified by this operation. + modifiesPayload(effects); +} + +void registerMyExtension(::mlir::DialectRegistry &registry) { + registry.addExtensions<MyExtension>(); +} diff --git a/mlir/examples/transform/Ch2/transform-opt/transform-opt.cpp b/mlir/examples/transform/Ch2/transform-opt/transform-opt.cpp new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch2/transform-opt/transform-opt.cpp @@ -0,0 +1,61 @@ +//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top-level file for the Transform dialect tutorial chapter 2.
+// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" + +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include <cstdlib> + +// Forward declarations of test passes that are used in this chapter for +// illustrative purposes. Test passes are not directly exposed for use in +// binaries other than mlir-opt, which is too big to serve as an example. +namespace mlir::test { +void registerTestTransformDialectEraseSchedulePass(); +void registerTestTransformDialectInterpreterPass(); +} // namespace mlir::test + +namespace test { +void registerTestTransformDialectExtension(mlir::DialectRegistry &); +} // namespace test + +int main(int argc, char **argv) { + // Register all "core" dialects and our transform dialect extension. + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + registerMyExtension(registry); + + // Register a handful of cleanup passes that we can run to make the output IR + // look nicer. + mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + mlir::registerSymbolDCEPass(); + + // Register the test passes. +#ifdef MLIR_INCLUDE_TESTS + mlir::test::registerTestTransformDialectEraseSchedulePass(); + mlir::test::registerTestTransformDialectInterpreterPass(); + test::registerTestTransformDialectExtension(registry); +#else + llvm::errs() << "warning: MLIR built without test passes, interpreter " + "testing will not be available\n"; +#endif // MLIR_INCLUDE_TESTS + + // Delegate to the MLIR utility for parsing and pass management. + return mlir::MlirOptMain(argc, argv, "transform-opt-ch2", registry) + .succeeded() + ?
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/mlir/examples/transform/Ch3/CMakeLists.txt b/mlir/examples/transform/Ch3/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch3/CMakeLists.txt @@ -0,0 +1,20 @@ +# For a better top-level template to copy, see examples/standalone. + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +add_subdirectory(include) +add_subdirectory(lib) + +add_dependencies(TransformExample transform-opt-ch3) +add_llvm_example(transform-opt-ch3 + transform-opt/transform-opt.cpp) + +target_link_libraries(transform-opt-ch3 + PRIVATE + MLIRIR + MLIRMlirOptMain + MLIRSideEffectInterfaces + MyExtensionCh3 +) diff --git a/mlir/examples/transform/Ch3/include/CMakeLists.txt b/mlir/examples/transform/Ch3/include/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch3/include/CMakeLists.txt @@ -0,0 +1,19 @@ +# Tell Tablegen to use MyExtension.td as input. +set(LLVM_TARGET_DEFINITIONS MyExtension.td) + +# Ask Tablegen to generate op declarations and definitions from ODS. +mlir_tablegen(MyExtension.h.inc -gen-op-decls) +mlir_tablegen(MyExtension.cpp.inc -gen-op-defs) + +# Tell Tablegen to use MyExtensionTypes.td as input. +set(LLVM_TARGET_DEFINITIONS MyExtensionTypes.td) + +# Ask Tablegen to generate type declarations and definitions from ODS. +mlir_tablegen(MyExtensionTypes.h.inc -gen-typedef-decls) +mlir_tablegen(MyExtensionTypes.cpp.inc -gen-typedef-defs) + +# Add a CMakeTarget we can depend on to ensure the generation happens before the compilation. +add_public_tablegen_target(MyExtensionCh3IncGen) + +# Don't forget to generate the documentation, this will produce a MyExtension.md under Dialects.
+add_mlir_doc(MyExtension MyExtensionCh3 Dialects/ -gen-op-doc) diff --git a/mlir/examples/transform/Ch3/include/MyExtension.h b/mlir/examples/transform/Ch3/include/MyExtension.h new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch3/include/MyExtension.h @@ -0,0 +1,31 @@ +//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 3 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" + +namespace mlir { +class CallOpInterface; +namespace func { +class CallOp; +} // namespace func +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "MyExtensionTypes.h.inc" + +#define GET_OP_CLASSES +#include "MyExtension.h.inc" + +// Registers our Transform dialect extension. +void registerMyExtension(::mlir::DialectRegistry &registry); diff --git a/mlir/examples/transform/Ch3/include/MyExtension.td b/mlir/examples/transform/Ch3/include/MyExtension.td new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch3/include/MyExtension.td @@ -0,0 +1,98 @@ +//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 3 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#ifndef MY_EXTENSION +#define MY_EXTENSION + +include "MyExtensionTypes.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Define the new operation. By convention, prefix its name with the name of the dialect +// extension, "my.". The full operation name will be further prefixed with "transform.". +def ChangeCallTargetOp : Op]> { + // Provide a brief and a full description. It is recommended that the latter describes + // the effects on the operands and how the operation processes various failure modes. + let summary = "Changes the callee of a call operation to the specified one"; + let description = [{ + For each `func.call` payload operation associated with the handle, changes its + callee to be the symbol whose name is provided as an attribute to this operation. + + Generates a silenceable failure if the operand is associated with payload operations + that are not `func.call`. + Only reads the operand. + }]; + + // The arguments include the handle to the payload operations and the attribute that + // specifies the new callee. The handle must implement TransformHandleTypeInterface. + // We use a string attribute as the symbol may not exist in the transform IR so the + // verification may fail. + let arguments = (ins + // Specify the type constraint on the input accepting only `func.call` payload + // operations. 
+ Transform_ConcreteOpType<"func.call">:$call, + StrAttr:$new_target); + + // The results are empty as the transformation does not produce any new payload. + let results = (outs); + + // Provide nice syntax. + let assemblyFormat = "$call `,` $new_target attr-dict `:` qualified(type($call))"; + + // Declare the function implementing the interface for a single payload operation. + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::func::CallOp call, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// Define another transform operation. +def CallToOp : Op { + // Summary and description omitted for brevity. + + // The argument is the handle to the payload operations. + let arguments = (ins CallOpInterfaceHandle:$call); + + // The result is the handle to the payload operations produced during the + // transformation. + let results = (outs TransformHandleTypeInterface:$transformed); + + // Provide nice syntax. + let assemblyFormat = "$call attr-dict `:` functional-type(operands, results)"; + + // Declare the function implementing the interface for a single payload operation. + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::CallOpInterface call, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +#endif // MY_EXTENSION diff --git a/mlir/examples/transform/Ch3/include/MyExtensionTypes.td b/mlir/examples/transform/Ch3/include/MyExtensionTypes.td new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch3/include/MyExtensionTypes.td @@ -0,0 +1,34 @@ +//===-- MyExtensionTypes.td - Transform dialect tutorial ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension types used in the +// Chapter 3 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#ifndef MY_EXTENSIONTYPES +#define MY_EXTENSIONTYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" + +// The Transform dialect allows additional types to be defined and injected. +def CallOpInterfaceHandle + : TypeDef]> { + + // The usual components of a type such as description, mnemonic and assembly format + // should be provided. + let summary = "handle to payload operations implementing CallOpInterface"; + let mnemonic = "my.call_op_interface"; + let assemblyFormat = ""; +} + +#endif // MY_EXTENSIONTYPES diff --git a/mlir/examples/transform/Ch3/lib/CMakeLists.txt b/mlir/examples/transform/Ch3/lib/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch3/lib/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_library( + # Library called MyExtensionCh3. + MyExtensionCh3 + + # Built from the following source files. + MyExtension.cpp + + # Make sure ODS declaration and definitions are generated before compiling this. + DEPENDS + MyExtensionCh3IncGen + + # Link in the transform dialect and all generated dialects. + LINK_LIBS PUBLIC + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect +) diff --git a/mlir/examples/transform/Ch3/lib/MyExtension.cpp b/mlir/examples/transform/Ch3/lib/MyExtension.cpp new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch3/lib/MyExtension.cpp @@ -0,0 +1,218 @@ +//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines Transform dialect extension operations used in the +// Chapter 3 of the Transform dialect tutorial. +// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GET_TYPEDEF_CLASSES +#include "MyExtensionTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "MyExtension.cpp.inc" + +//===---------------------------------------------------------------------===// +// MyExtension +//===---------------------------------------------------------------------===// + +// Define a new transform dialect extension. This uses the CRTP idiom to +// identify extensions. +class MyExtension + : public ::mlir::transform::TransformDialectExtension<MyExtension> { +public: + // The extension must inherit the base constructor. + using Base::Base; + + // This function initializes the extension, similarly to `initialize` in + // dialect definitions. List individual operations and dependent dialects + // here. + void init(); +}; + +void MyExtension::init() { + // Similarly to dialects, an extension can declare a dependent dialect. This + // dialect will be loaded along with the extension and, therefore, along with + // the Transform dialect. Only declare as dependent the dialects that contain + // the attributes or types used by transform operations. Do NOT declare as + // dependent the dialects produced during the transformation.
+ // declareDependentDialect(); + + // When transformations are applied, they may produce new operations from + // previously unloaded dialects. Typically, a pass would need to declare + // itself dependent on the dialects containing such new operations. To avoid + // confusion with the dialects the extension itself depends on, the Transform + // dialect differentiates between: + // - dependent dialects, which are used by the transform operations, and + // - generated dialects, which contain the entities (attributes, operations, + // types) that may be produced by applying the transformation even when + // not present in the original payload IR. + // In the following chapter, we will add operations that generate function + // calls and structured control flow operations, so let's declare the + // corresponding dialects as generated. + declareGeneratedDialect<::mlir::scf::SCFDialect>(); + declareGeneratedDialect<::mlir::func::FuncDialect>(); + + // Register the additional transform dialect types with the dialect. List all + // types generated from ODS. + registerTypes< +#define GET_TYPEDEF_LIST +#include "MyExtensionTypes.cpp.inc" + >(); + + // ODS generates these helpers for type printing and parsing, but the + // Transform dialect provides its own support for types supplied by the + // extension. Reference these functions to avoid a compiler warning. + (void)generatedTypeParser; + (void)generatedTypePrinter; + + // Finally, we register the additional transform operations with the dialect. + // List all operations generated from ODS. This call will perform additional + // checks that the operations implement the transform and memory effect + // interfaces required by the dialect interpreter and assert if they do not.
+ registerTransformOps< +#define GET_OP_LIST +#include "MyExtension.cpp.inc" + >(); +} + +//===---------------------------------------------------------------------===// +// ChangeCallTargetOp +//===---------------------------------------------------------------------===// + +static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { + call.setCallee(newTarget); +} + +// Implementation of our transform dialect operation. +// This operation returns a tri-state result that can be one of: +// - success when the transformation succeeded; +// - definite failure when the transformation failed in such a way that +// following transformations are impossible or undesirable; typically it +// could have left the payload IR in an invalid state. It is expected that +// a diagnostic is emitted immediately before returning the definite +// failure; +// - silenceable failure when the transformation failed but following +// transformations are still applicable; typically this means that a +// precondition for the transformation is not satisfied and the payload IR +// has not been modified. +// The silenceable failure additionally carries a Diagnostic that can be +// emitted to the user. +::mlir::DiagnosedSilenceableFailure +mlir::transform::ChangeCallTargetOp::applyToOne( + // The single payload operation to which the transformation is applied. + ::mlir::func::CallOp call, + // The payload IR entities that will be appended to lists associated with + // the results of this transform operation. This list contains one entry per + // result. + ::mlir::transform::ApplyToEachResultList &results, + // The transform application state. This object can be used to query the + // current associations between transform IR values and payload IR entities. + // It can also carry additional user-defined state. + ::mlir::transform::TransformState &state) { + + // Dispatch to the actual transformation. + updateCallee(call, getNewTarget()); + + // If everything went well, return success.
+ return DiagnosedSilenceableFailure::success(); +} + +void mlir::transform::ChangeCallTargetOp::getEffects( + ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) { + // Indicate that the `call` handle is only read by this operation because the + // associated operation is not erased but rather modified in-place, so the + // reference to it remains valid. + onlyReadsHandle(getCall(), effects); + + // Indicate that the payload is modified by this operation. + modifiesPayload(effects); +} + +//===---------------------------------------------------------------------===// +// CallToOp +//===---------------------------------------------------------------------===// + +static mlir::Operation *replaceCallWithOp(mlir::CallOpInterface call) { + // Construct an operation from an unregistered dialect. This is discouraged + // and is only used here for brevity of the overall example. + mlir::OperationState state(call.getLoc(), "my.mm4"); + state.types.assign(call->result_type_begin(), call->result_type_end()); + state.operands.assign(call->operand_begin(), call->operand_end()); + + mlir::OpBuilder builder(call); + mlir::Operation *replacement = builder.create(state); + call->replaceAllUsesWith(replacement->getResults()); + call->erase(); + return replacement; +} + +// See above for the signature description. +mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( + mlir::CallOpInterface call, mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + + // Dispatch to the actual transformation. + Operation *replacement = replaceCallWithOp(call); + + // Associate the payload operation produced by the rewrite with the result + // handle of this transform operation. + results.push_back(replacement); + + // If everything went well, return success. 
+ return DiagnosedSilenceableFailure::success(); +} + +//===---------------------------------------------------------------------===// +// CallOpInterfaceHandleType +//===---------------------------------------------------------------------===// + +// The interface declares this method to verify constraints this type has on +// payload operations. It returns the now familiar tri-state result. +mlir::DiagnosedSilenceableFailure +mlir::transform::CallOpInterfaceHandleType::checkPayload( + // Location at which diagnostics should be emitted. + mlir::Location loc, + // List of payload operations that are about to be associated with the + // handle that has this type. + llvm::ArrayRef<mlir::Operation *> payload) const { + + // All payload operations are expected to implement CallOpInterface, check + // this. + for (Operation *op : payload) { + if (llvm::isa<mlir::CallOpInterface>(op)) + continue; + + // By convention, these verifiers always emit a silenceable failure since + // they are checking a precondition. + DiagnosedSilenceableFailure diag = + emitSilenceableError(loc) + << "expected the payload operation to implement CallOpInterface"; + diag.attachNote(op->getLoc()) << "offending operation"; + return diag; + } + + // If everything is okay, return success. + return DiagnosedSilenceableFailure::success(); +} + +//===---------------------------------------------------------------------===// +// Extension registration +//===---------------------------------------------------------------------===// + +void registerMyExtension(::mlir::DialectRegistry &registry) { + registry.addExtensions<MyExtension>(); +} diff --git a/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp b/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp @@ -0,0 +1,61 @@ +//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top-level file for the Transform dialect tutorial chapter 3. +// +//===----------------------------------------------------------------------===// + +#include "MyExtension.h" + +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include <cstdlib> + +// Forward declarations of test passes that are used in this chapter for +// illustrative purposes. Test passes are not directly exposed for use in +// binaries other than mlir-opt, which is too big to serve as an example. +namespace mlir::test { +void registerTestTransformDialectEraseSchedulePass(); +void registerTestTransformDialectInterpreterPass(); +} // namespace mlir::test + +namespace test { +void registerTestTransformDialectExtension(mlir::DialectRegistry &); +} // namespace test + +int main(int argc, char **argv) { + // Register all "core" dialects and our transform dialect extension. + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + registerMyExtension(registry); + + // Register a handful of cleanup passes that we can run to make the output IR + // look nicer. + mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + mlir::registerSymbolDCEPass(); + + // Register the test passes. +#ifdef MLIR_INCLUDE_TESTS + mlir::test::registerTestTransformDialectEraseSchedulePass(); + mlir::test::registerTestTransformDialectInterpreterPass(); + test::registerTestTransformDialectExtension(registry); +#else + llvm::errs() << "warning: MLIR built without test passes, interpreter " + "testing will not be available\n"; +#endif // MLIR_INCLUDE_TESTS + + // Delegate to the MLIR utility for parsing and pass management.
+ return mlir::MlirOptMain(argc, argv, "transform-opt-ch3", registry) + .succeeded() + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/mlir/examples/transform/README.md b/mlir/examples/transform/README.md new file mode 100644 --- /dev/null +++ b/mlir/examples/transform/README.md @@ -0,0 +1,4 @@ +Transform Dialect Tutorial is available at +https://mlir.llvm.org/docs/Tutorials/Transform. + +Test files are located under `mlir/test/Examples/Transform`. diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -125,6 +125,8 @@ toyc-ch3 toyc-ch4 toyc-ch5 + transform-opt-ch2 + transform-opt-ch3 ) if(MLIR_ENABLE_EXECUTION_ENGINE) list(APPEND MLIR_TEST_DEPENDS diff --git a/mlir/test/Examples/transform/Ch1/invalidation-1.mlir b/mlir/test/Examples/transform/Ch1/invalidation-1.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch1/invalidation-1.mlir @@ -0,0 +1,98 @@ +// RUN: mlir-opt %s \ +// RUN: --pass-pipeline="builtin.module(test-transform-dialect-interpreter{ \ +// RUN: bind-first-extra-to-ops=linalg.matmul \ +// RUN: bind-second-extra-to-ops=linalg.elemwise_binary \ +// RUN: enable-expensive-checks},canonicalize,cse,symbol-dce)" \ +// RUN: --split-input-file --verify-diagnostics + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + // expected-note @below {{handle to invalidated ops}} + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + // The actual tiling transformation takes tile sizes as attributes. 
+ // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + %loop, %tiled = transform.structured.tile_to_forall_op %arg1 tile_sizes [4, 32] + : (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op) + + // This is trying to use an invalidated handle leading to undefined behavior. + // expected-error @below {{uses a handle invalidated by a previously executed transform op}} + transform.test_print_remark_at_operand %arg1, "remark" : !transform.op<"linalg.matmul"> + transform.yield +} + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-note @below {{payload op}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + // We can cast one type to another as long as operations are compatible + // with both types. This creates "aliasing" handles.
+ // expected-note @below {{handle to invalidated ops}} + %casted = transform.cast %arg1 : !transform.op<"linalg.matmul"> to + !transform.any_op + + // The actual tiling transformation takes tile sizes as attributes. + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + %loop, %tiled = transform.structured.tile_to_forall_op %arg1 tile_sizes [4, 32] + : (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op) + + // Consuming an operand invalidates the consumed handle and any other handle that is + // associated with the same payload operations, or payload operations nested in them. + // expected-error @below {{uses a handle invalidated by a previously executed transform op}} + transform.test_print_remark_at_operand %casted, "remark" + : !transform.any_op + transform.yield +} + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + // expected-note @below {{payload op}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU).
+ %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} diff --git a/mlir/test/Examples/transform/Ch1/invalidation-2.mlir b/mlir/test/Examples/transform/Ch1/invalidation-2.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch1/invalidation-2.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt %s \ +// RUN: --pass-pipeline="builtin.module(test-transform-dialect-interpreter{ \ +// RUN: bind-first-extra-to-ops=linalg.matmul \ +// RUN: bind-second-extra-to-ops=linalg.elemwise_binary \ +// RUN: enable-expensive-checks},canonicalize,cse,symbol-dce)" \ +// RUN: --split-input-file --verify-diagnostics + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + + // expected-note @below {{nested payload op}} + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + + // expected-note @below {{ancestor payload op}} + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU).
+ %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// Declaration of the "microkernel" function that we will be targeting. +func.func private @microkernel( + %lhs: tensor<4x512xf32>, + %rhs: tensor<512x4xf32>, + %bias: tensor<4x4xf32>, + %init: tensor<4x4xf32>, + %output: tensor<4x4xf32>) -> tensor<4x4xf32> + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation. + %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It produces a + // handle to the loop generated during tiling. + %loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one-by-one. This requires the operation that is being fused + // to define the value used within the loop, so the order of such fusions + // is important. We could also use "transform.merge_handles" to obtain + // a single handle to all operations and give it to `fuse_into_containing_op` + // that would take care of the ordering in this case.
+ %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2 + : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile again to get the desired size. Note that this time we tile the + // "add" operation and fuse matmul into the loop, but leave the + // "max" operation unaffected. This illustrates the precise targeting with the transform + // dialect. Otherwise, it is difficult to differentiate "add" and "max", both + // of which have the same kind. + %loop_second, %tiled_second = transform.structured.tile_to_forall_op %add_fused tile_sizes [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused_2, %loop_second_2 = + transform.structured.fuse_into_containing_op %matmul_fused into %loop_second + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Since outlining is currently only implemented for region-holding operations + // such as loops, use tiling to size 1 to materialize the outer loop that is + // going to be outlined.
+ %loop_third, %_0 = transform.structured.tile_to_forall_op %tiled_second tile_sizes [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // expected-note @below {{handle to invalidated ops}} + %f, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">) + + // expected-error @below {{uses a handle invalidated by a previously executed transform op}} + transform.test_print_remark_at_operand %f, "fused" : !transform.any_op + + transform.yield +} diff --git a/mlir/test/Examples/transform/Ch1/sequence.mlir b/mlir/test/Examples/transform/Ch1/sequence.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch1/sequence.mlir @@ -0,0 +1,111 @@ +// RUN: mlir-opt %s \ +// RUN: --pass-pipeline="builtin.module(test-transform-dialect-interpreter{ \ +// RUN: bind-first-extra-to-ops=linalg.matmul \ +// RUN: bind-second-extra-to-ops=linalg.elemwise_binary \ +// RUN: enable-expensive-checks},canonicalize,cse,symbol-dce)" |\ +// RUN: FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. 
+ %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// CHECK: func @outlined +// CHECK: linalg.matmul +// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn<add>} + +// CHECK-LABEL: func @fc_relu +// CHECK: scf.forall +// CHECK: scf.forall +// CHECK: %[[SLICE4:.+]] = tensor.extract_slice +// CHECK: %[[SLICE5:.+]] = tensor.extract_slice +// CHECK: %[[SLICE6:.+]] = tensor.extract_slice +// CHECK: %[[SLICE7:.+]] = tensor.extract_slice +// CHECK: %[[SLICE8:.+]] = tensor.extract_slice +// CHECK: func.call @outlined(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]]) +// CHECK-NOT: linalg.matmul +// CHECK-NOT: linalg.elemwise_binary +// CHECK: scf.forall.in_parallel +// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>} +// CHECK: scf.forall.in_parallel + +// Declaration of the "microkernel" function that we will be targeting. +func.func private @microkernel( + %lhs: tensor<4x512xf32>, + %rhs: tensor<512x4xf32>, + %bias: tensor<4x4xf32>, + %init: tensor<4x4xf32>, + %output: tensor<4x4xf32>) -> tensor<4x4xf32> + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation.
+ %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It produces a + // handle to the loop generated during tiling. + %loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one-by-one. This requires the operation that is being fused + // to define the value used within the loop, so the order of such fusions + // is important. We could also use "transform.merge_handles" to obtain + // a single handle to all operations and give it to `fuse_into_containing_op` + // that would take care of the ordering in this case. + %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2 + : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile again to get the desired size. Note that this time we tile the + // "add" operation and fuse matmul into the loop, but leave the + // "max" operation unaffected. This illustrates the precise targeting with the transform + // dialect. Otherwise, it is difficult to differentiate "add" and "max", both + // of which have the same kind.
+ %loop_second, %tiled_second = transform.structured.tile_to_forall_op %add_fused tile_sizes [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused_2, %loop_second_2 = + transform.structured.fuse_into_containing_op %matmul_fused into %loop_second + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Since outlining is currently only implemented for region-holding operations + // such as loops, use tiling to size 1 to materialize the outer loop that is + // going to be outlined. + %loop_third, %_0 = transform.structured.tile_to_forall_op %tiled_second tile_sizes [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">) + + transform.yield +} diff --git a/mlir/test/Examples/transform/Ch2/invalid.mlir b/mlir/test/Examples/transform/Ch2/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch2/invalid.mlir @@ -0,0 +1,11 @@ +// RUN: transform-opt-ch2 %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics + +// expected-note @below {{offending payload}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + // expected-error @below {{only applies to func.call payloads}} + transform.my.change_call_target %arg0, "updated" : !transform.any_op + yield + } +} diff --git a/mlir/test/Examples/transform/Ch2/ops.mlir b/mlir/test/Examples/transform/Ch2/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch2/ops.mlir @@ -0,0 +1,26 @@ +// RUN: transform-opt-ch2 %s --test-transform-dialect-interpreter | FileCheck %s + +// 
****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +func.func private @orig() +func.func private @updated() + +// CHECK-LABEL: func @test +func.func @test() { + // CHECK: call @updated + call @orig() : () -> () + return +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.my.change_call_target %{{.*}}, "updated" : !transform.any_op + transform.my.change_call_target %call, "updated" : !transform.any_op + transform.yield +} diff --git a/mlir/test/Examples/transform/Ch2/sequence.mlir b/mlir/test/Examples/transform/Ch2/sequence.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch2/sequence.mlir @@ -0,0 +1,110 @@ +// RUN: transform-opt-ch2 %s \ +// RUN: --pass-pipeline="builtin.module(test-transform-dialect-interpreter{ \ +// RUN: bind-first-extra-to-ops=linalg.matmul \ +// RUN: bind-second-extra-to-ops=linalg.elemwise_binary \ +// RUN: enable-expensive-checks},canonicalize,cse,symbol-dce)" |\ +// RUN: FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. +// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. 
+ %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// CHECK-LABEL: func @fc_relu +// CHECK: scf.forall +// CHECK: scf.forall +// CHECK: %[[SLICE4:.+]] = tensor.extract_slice +// CHECK: %[[SLICE5:.+]] = tensor.extract_slice +// CHECK: %[[SLICE6:.+]] = tensor.extract_slice +// CHECK: %[[SLICE7:.+]] = tensor.extract_slice +// CHECK: %[[SLICE8:.+]] = tensor.extract_slice +// CHECK: func.call @microkernel(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]]) +// CHECK-NOT: linalg.matmul +// CHECK-NOT: linalg.elemwise_binary +// CHECK: scf.forall.in_parallel +// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>} +// CHECK: scf.forall.in_parallel + +// Declaration of the "microkernel" function that we will be targeting. +func.func private @microkernel( + %lhs: tensor<4x512xf32>, + %rhs: tensor<512x4xf32>, + %bias: tensor<4x4xf32>, + %init: tensor<4x4xf32>, + %output: tensor<4x4xf32>) -> tensor<4x4xf32> + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation.
+ %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It produces a + // handle to the loop generated during tiling. + %loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one-by-one. This requires the operation that is being fused + // to define the value used within the loop, so the order of such fusions + // is important. We could also use "transform.merge_handles" to obtain + // a single handle to all operations and give it to `fuse_into_containing_op` + // that would take care of the ordering in this case. + %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2 + : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile again to get the desired size. Note that this time we tile the + // "add" operation and fuse matmul into the loop, but leave the + // "max" operation unaffected. This illustrates the precise targeting with the transform + // dialect. Otherwise, it is difficult to differentiate "add" and "max", both + // of which have the same kind.
+ %loop_second, %tiled_second = transform.structured.tile_to_forall_op %add_fused tile_sizes [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused_2, %loop_second_2 = + transform.structured.fuse_into_containing_op %matmul_fused into %loop_second + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Since outlining is currently only implemented for region-holding operations + // such as loops, use tiling to size 1 to materialize the outer loop that is + // going to be outlined. + %loop_third, %_0 = transform.structured.tile_to_forall_op %tiled_second tile_sizes [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Rewrite the call target. 
+ transform.my.change_call_target %call, "microkernel" : !transform.any_op + + transform.yield +} diff --git a/mlir/test/Examples/transform/Ch3/invalid.mlir b/mlir/test/Examples/transform/Ch3/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch3/invalid.mlir @@ -0,0 +1,10 @@ +// RUN: transform-opt-ch3 %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics + +// expected-note @below {{offending operation}} +module { + transform.sequence failures(suppress) { + // expected-error @below {{expected the payload operation to implement CallOpInterface}} + ^bb0(%arg0: !transform.my.call_op_interface): + yield + } +} diff --git a/mlir/test/Examples/transform/Ch3/ops.mlir b/mlir/test/Examples/transform/Ch3/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch3/ops.mlir @@ -0,0 +1,46 @@ +// RUN: transform-opt-ch3 %s --test-transform-dialect-interpreter \ +// RUN: --allow-unregistered-dialect --split-input-file | FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/Transform accordingly. 
+// +// **************************************************************************** + +func.func private @orig() +func.func private @updated() + +// CHECK-LABEL: func @test1 +func.func @test1() { + // CHECK: call @updated + call @orig() : () -> () + return +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.op<"func.call"> + // CHECK: transform.my.change_call_target %{{.*}}, "updated" : !transform.op<"func.call"> + transform.my.change_call_target %call, "updated" : !transform.op<"func.call"> + transform.yield +} + +// ----- + +func.func private @orig() + +// CHECK-LABEL: func @test2 +func.func @test2() { + // CHECK: "my.mm4" + call @orig() : () -> () + return +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.my.call_op_interface + // CHECK: transform.my.call_to_op %{{.*}} : (!transform.my.call_op_interface) -> !transform.any_op + transform.my.call_to_op %call : (!transform.my.call_op_interface) -> !transform.any_op + transform.yield +} diff --git a/mlir/test/Examples/transform/Ch3/sequence.mlir b/mlir/test/Examples/transform/Ch3/sequence.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Examples/transform/Ch3/sequence.mlir @@ -0,0 +1,110 @@ +// RUN: transform-opt-ch2 %s \ +// RUN: --pass-pipeline="builtin.module(test-transform-dialect-interpreter{ \ +// RUN: bind-first-extra-to-ops=linalg.matmul \ +// RUN: bind-second-extra-to-ops=linalg.elemwise_binary \ +// RUN: enable-expensive-checks},canonicalize,cse,symbol-dce)" |\ +// RUN: FileCheck %s + +// ****************************** IMPORTANT NOTE ****************************** +// +// If you are changing this file, you may also need to change +// mlir/docs/Tutorials/transform accordingly.
+// +// **************************************************************************** + +// Original function to optimize. +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} + +// CHECK-LABEL: func @fc_relu +// CHECK: scf.forall +// CHECK: scf.forall +// CHECK: %[[SLICE4:.+]] = tensor.extract_slice +// CHECK: %[[SLICE5:.+]] = tensor.extract_slice +// CHECK: %[[SLICE6:.+]] = tensor.extract_slice +// CHECK: %[[SLICE7:.+]] = tensor.extract_slice +// CHECK: %[[SLICE8:.+]] = tensor.extract_slice +// CHECK: func.call @microkernel(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]]) +// CHECK-NOT: linalg.matmul +// CHECK-NOT: linalg.elemwise_binary +// CHECK: scf.forall.in_parallel +// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>} +// CHECK: scf.forall.in_parallel + +// Declaration of the "microkernel" function that we will be targeting.
+func.func private @microkernel( + %lhs: tensor<4x512xf32>, + %rhs: tensor<512x4xf32>, + %bias: tensor<4x4xf32>, + %init: tensor<4x4xf32>, + %output: tensor<4x4xf32>) -> tensor<4x4xf32> + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op, + %arg1: !transform.op<"linalg.matmul">, + %arg2: !transform.op<"linalg.elemwise_binary">): + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation. + %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It produces + // handles to the generated loop and to the tiled operation. + %loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one-by-one. This requires the operation that is being fused + // to define the value used within the loop, so the order of such fusions + // is important. We could also use "transform.merge_handles" to obtain + // a single handle to all operations and give it to `fuse_into_containing_op`, + // which would take care of the ordering in this case. + %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2 + : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile again to get the desired size. Note that this time it tiles the + // "add" operation and fuses matmul into the loop, but doesn't affect the + // "max" operation. This illustrates the precise targeting with the transform + // dialect.
 Otherwise, it is difficult to differentiate "add" and "max", both + // of which have the same kind. + %loop_second, %tiled_second = transform.structured.tile_to_forall_op %add_fused tile_sizes [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %matmul_fused_2, %loop_second_2 = + transform.structured.fuse_into_containing_op %matmul_fused into %loop_second + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Since outlining is currently only implemented for region-holding operations + // such as loops, use tiling to size 1 to materialize the outer loop that is + // going to be outlined. + %loop_third, %_0 = transform.structured.tile_to_forall_op %tiled_second tile_sizes [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">) + + // Rewrite the call target. + transform.my.change_call_target %call, "microkernel" : !transform.op<"func.call"> + + transform.yield +} diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -128,6 +128,8 @@ ToolSubst("toyc-ch5", unresolved="ignore"), ToolSubst("toyc-ch6", unresolved="ignore"), ToolSubst("toyc-ch7", unresolved="ignore"), + ToolSubst("transform-opt-ch2", unresolved="ignore"), + ToolSubst("transform-opt-ch3", unresolved="ignore"), ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"), ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"), ]