diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -442,7 +442,7 @@ attributes: * `OPERATION_NAME` attribute with the `str` fully qualified operation name - (i.e. `std.absf`). + (i.e. `math.abs`). * An `__init__` method for the *default builder* if one is defined or inferred for the operation. * `@property` getter for each operand or result (using an auto-generated name diff --git a/mlir/docs/BufferDeallocationInternals.md b/mlir/docs/BufferDeallocationInternals.md --- a/mlir/docs/BufferDeallocationInternals.md +++ b/mlir/docs/BufferDeallocationInternals.md @@ -72,8 +72,8 @@ ```mlir func @mixedAllocation(%arg0: i1) { - %0 = alloca() : memref<2xf32> // aliases: %2 - %1 = alloc() : memref<2xf32> // aliases: %2 + %0 = memref.alloca() : memref<2xf32> // aliases: %2 + %1 = memref.alloc() : memref<2xf32> // aliases: %2 cond_br %arg0, ^bb1, ^bb2 ^bb1: use(%0) @@ -405,7 +405,7 @@ ```mlir func @nested_region_control_flow(%arg0 : index, %arg1 : index) -> memref { - %0 = cmpi "eq", %arg0, %arg1 : index + %0 = arith.cmpi "eq", %arg0, %arg1 : index %1 = memref.alloc(%arg0, %arg0) : memref %2 = scf.if %0 -> (memref) { scf.yield %1 : memref // %2 will be an alias of %1 @@ -426,7 +426,7 @@ ```mlir func @nested_region_control_flow(%arg0: index, %arg1: index) -> memref { - %0 = cmpi "eq", %arg0, %arg1 : index + %0 = arith.cmpi "eq", %arg0, %arg1 : index %1 = memref.alloc(%arg0, %arg0) : memref %2 = scf.if %0 -> (memref) { scf.yield %1 : memref @@ -518,7 +518,7 @@ %res: memref<2xf32>) { %0 = scf.for %i = %lb to %ub step %step iter_args(%iterBuf = %buf) -> memref<2xf32> { - %1 = cmpi "eq", %i, %ub : index + %1 = arith.cmpi "eq", %i, %ub : index %2 = scf.if %1 -> (memref<2xf32>) { %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias due to a // divergent allocation @@ -557,7 +557,7 @@ %4 = memref.clone %buf : (memref<2xf32>) -> (memref<2xf32>) %0 = scf.for %i = %lb to %ub step 
%step iter_args(%iterBuf = %4) -> memref<2xf32> { - %1 = cmpi "eq", %i, %ub : index + %1 = arith.cmpi "eq", %i, %ub : index %2 = scf.if %1 -> (memref<2xf32>) { %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias use(%3) @@ -666,7 +666,7 @@ indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %temp { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): - %tmp2 = exp %gen2_arg0 : f32 + %tmp2 = math.exp %gen2_arg0 : f32 test.yield %tmp2 : f32 }: memref<2xf32>, memref<2xf32> %result = memref.clone %temp : (memref<2xf32>) -> (memref<2xf32>) @@ -685,7 +685,7 @@ indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %result { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): - %tmp2 = exp %gen2_arg0 : f32 + %tmp2 = math.exp %gen2_arg0 : f32 test.yield %tmp2 : f32 }: memref<2xf32>, memref<2xf32> return diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md --- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -190,8 +190,8 @@ `BufferizeTypeConverter`, which comes pre-loaded with the necessary conversions and materializations between `tensor` and `memref`. -In this case, the `MemRefOpsDialect` is marked as legal, so the `tensor_load` -and `buffer_cast` ops, which are inserted automatically by the dialect +In this case, the `MemRefOpsDialect` is marked as legal, so the `memref.tensor_load` +and `memref.buffer_cast` ops, which are inserted automatically by the dialect conversion framework as materializations, are legal. 
There is a helper `populateBufferizeMaterializationLegality` ([code](https://github.com/llvm/llvm-project/blob/a0b65a7bcd6065688189b3d678c42ed6af9603db/mlir/include/mlir/Transforms/Bufferize.h#L53)) @@ -199,46 +199,50 @@ ### Other partial bufferization examples -- `linalg-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Linalg/bufferize.mlir#L1)) - - - Bufferizes the `linalg` dialect. - - This is an example of how to simultaneously bufferize all the ops that - satisfy a certain OpInterface with a single pattern. Specifically, - `BufferizeAnyLinalgOp` - ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L170)) - bufferizes any ops that implements the `LinalgOp` interface. - -- `scf-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/SCF/bufferize.mlir#L1)) - - - Bufferizes ops from the `scf` dialect. - - This is an example of how to bufferize ops that implement - `RegionBranchOpInterface` (that is, they use regions to represent control - flow). - - The bulk of the work is done by - `lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp` - ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp#L1)), - which is well-commented and covers how to correctly convert ops that contain - regions. 
- -- `func-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/func-bufferize.mlir#L1)) - - - Bufferizes `func`, `call`, and `BranchOpInterface` ops. - - This is an example of how to bufferize ops that have multi-block regions. - - This is an example of a pass that is not split along dialect subdivisions. - -- `tensor-constant-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir#L1)) - - Bufferizes only `std.constant` ops of `tensor` type. - - This is an example of setting up the legality so that only a subset of - `std.constant` ops get bufferized. - - This is an example of a pass that is not split along dialect subdivisions. +- `linalg-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Linalg/bufferize.mlir#L1)) + + - Bufferizes the `linalg` dialect. + - This is an example of how to simultaneously bufferize all the ops that + satisfy a certain OpInterface with a single pattern. Specifically, + `BufferizeAnyLinalgOp` + ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L170)) + bufferizes any ops that implements the `LinalgOp` interface. 
+ +- `scf-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/SCF/bufferize.mlir#L1)) + + - Bufferizes ops from the `scf` dialect. + - This is an example of how to bufferize ops that implement + `RegionBranchOpInterface` (that is, they use regions to represent + control flow). + - The bulk of the work is done by + `lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp` + ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp#L1)), + which is well-commented and covers how to correctly convert ops that + contain regions. + +- `func-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/func-bufferize.mlir#L1)) + + - Bufferizes `func`, `call`, and `BranchOpInterface` ops. + - This is an example of how to bufferize ops that have multi-block + regions. + - This is an example of a pass that is not split along dialect + subdivisions. + +- `tensor-constant-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir#L1)) + + - Bufferizes only `arith.constant` ops of `tensor` type. + - This is an example of setting up the legality so that only a subset of + `arith.constant` ops get bufferized. + - This is an example of a pass that is not split along dialect + subdivisions.
## How to write a finalizing bufferization pass @@ -246,10 +250,10 @@ from the program. The easiest way to write a finalizing bufferize pass is to not write one at all! -MLIR provides a pass `finalizing-bufferize` which eliminates the `tensor_load` / -`buffer_cast` materialization ops inserted by partial bufferization passes -and emits an error if that is not sufficient to remove all tensors from the -program. +MLIR provides a pass `finalizing-bufferize` which eliminates the `memref.tensor_load` / +`memref.buffer_cast` materialization ops inserted by partial bufferization +passes and emits an error if that is not sufficient to remove all tensors from +the program. This pass is sufficient when partial bufferization passes have bufferized all the ops in the program, leaving behind only the materializations. When possible, @@ -267,8 +271,8 @@ recommended in new code. A helper, `populateEliminateBufferizeMaterializationsPatterns` ([code](https://github.com/llvm/llvm-project/blob/a0b65a7bcd6065688189b3d678c42ed6af9603db/mlir/include/mlir/Transforms/Bufferize.h#L58)) -is available for such passes to provide patterns that eliminate `tensor_load` -and `buffer_cast`. +is available for such passes to provide patterns that eliminate `memref.tensor_load` +and `memref.buffer_cast`. 
## Changes since [the talk](#the-talk) diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -486,17 +486,17 @@ want to allocate memory and store some computation (in pseudocode): ```mlir -%dst = addi %lhs, %rhs +%dst = arith.addi %lhs, %rhs ``` into ```mlir %shape = shape %lhs -%mem = alloc %shape -%sum = addi %lhs, %rhs -store %mem, %sum -%dst = load %mem +%mem = memref.alloc %shape +%sum = arith.addi %lhs, %rhs +memref.store %mem, %sum +%dst = memref.load %mem ``` We cannot fit in with just one result pattern given `store` does not return a diff --git a/mlir/docs/Diagnostics.md b/mlir/docs/Diagnostics.md --- a/mlir/docs/Diagnostics.md +++ b/mlir/docs/Diagnostics.md @@ -301,7 +301,7 @@ // Expect an error on an adjacent line. func @foo(%a : f32) { // expected-error@+1 {{unknown comparison predicate "foo"}} - %result = cmpf "foo", %a, %a : f32 + %result = arith.cmpf "foo", %a, %a : f32 return } diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -66,7 +66,7 @@ - This action signals that only some instances of a given operation are legal. This allows for defining fine-tune constraints, e.g. saying that - `addi` is only legal when operating on 32-bit integers. + `arith.addi` is only legal when operating on 32-bit integers. * Illegal diff --git a/mlir/docs/Dialects/Affine.md b/mlir/docs/Dialects/Affine.md --- a/mlir/docs/Dialects/Affine.md +++ b/mlir/docs/Dialects/Affine.md @@ -54,7 +54,7 @@ ```mlir #affine_map2to3 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0, d1 - s0)> // Binds %N to the s0 symbol in affine_map2to3. 
-%x = alloc()[%N] : memref<40x50xf32, #affine_map2to3> +%x = memref.alloc()[%N] : memref<40x50xf32, #affine_map2to3> ``` ### Restrictions on Dimensions and Symbols @@ -192,10 +192,10 @@ // Use an affine mapping definition in an alloc operation, binding the // SSA value %N to the symbol s0. -%a = alloc()[%N] : memref<4x4xf32, #affine_map42> +%a = memref.alloc()[%N] : memref<4x4xf32, #affine_map42> // Same thing with an inline affine mapping definition. -%b = alloc()[%N] : memref<4x4xf32, affine_map<(d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)>> +%b = memref.alloc()[%N] : memref<4x4xf32, affine_map<(d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)>> ``` ### Semi-affine maps @@ -403,8 +403,8 @@ space 1 at indices [%k + 7, %l], would be specified as follows: - %num_elements = constant 256 + %num_elements = arith.constant 256 : index - %idx = constant 0 : index - %tag = alloc() : memref<1xi32, 4> + %idx = arith.constant 0 : index + %tag = memref.alloc() : memref<1xi32, 4> affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx], %num_elements : memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2> diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md --- a/mlir/docs/Dialects/Linalg/_index.md +++ b/mlir/docs/Dialects/Linalg/_index.md @@ -125,14 +125,14 @@ #map0 = affine_map<(d0) -> (d0 * 2 + 1)> func @example(%arg0: memref, %arg1: memref, #map0>) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = dim %arg0, %c0 : memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.dim %arg0, %c0 : memref scf.for %arg2 = %c0 to %0 step %c1 { - %1 = load %arg0[%arg2] : memref - %2 = load %arg1[%arg2] : memref, #map0> + %1 = memref.load %arg0[%arg2] : memref + %2 = memref.load %arg1[%arg2] : memref, #map0> %3 = "some_compute"(%1, %2) : (f32, vector<4xf32>) -> vector<4xf32> - store %3, %arg1[%arg2] : memref, #map0> + memref.store %3, %arg1[%arg2] : memref, #map0> } return } @@ -207,16 +207,16 @@ #map0 = affine_map<(d0, d1) -> (d0 * 2 + d1 * 2)>
func @example(%arg0: memref<8x?xf32, #map0>, %arg1: memref>) { - %c8 = constant 8 : index - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = dim %arg0, %c1 : memref<8x?xf32, #map0> + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.dim %arg0, %c1 : memref<8x?xf32, #map0> scf.for %arg2 = %c0 to %0 step %c1 { scf.for %arg3 = %c0 to %c8 step %c1 { - %1 = load %arg0[%arg3, %arg2] : memref<8x?xf32, #map0> - %2 = load %arg1[%arg3] : memref> + %1 = memref.load %arg0[%arg3, %arg2] : memref<8x?xf32, #map0> + %2 = memref.load %arg1[%arg3] : memref> %3 = "some_compute"(%1, %2) : (f32, vector<4xf32>) -> vector<4xf32> - store %3, %arg1[%arg3] : memref> + memref.store %3, %arg1[%arg3] : memref> } } return @@ -314,7 +314,7 @@ ins(%A, %B: memref, memref) outs(%C: memref) { ^bb0(%a: f32, %b: f32, %c: f32): - %d = addf %a, %b : f32 + %d = arith.addf %a, %b : f32 linalg.yield %d : f32 } @@ -330,16 +330,16 @@ ```mlir func @example(%arg0: memref, %arg1: memref, %arg2: memref) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %0 = dim %arg0, %c0 : memref - %1 = dim %arg0, %c1 : memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.dim %arg0, %c1 : memref scf.for %arg3 = %c0 to %0 step %c1 { scf.for %arg4 = %c0 to %1 step %c1 { - %2 = load %arg0[%arg3, %arg4] : memref - %3 = load %arg1[%arg3, %arg4] : memref - %4 = addf %2, %3 : f32 - store %4, %arg2[%arg3, %arg4] : memref + %2 = memref.load %arg0[%arg3, %arg4] : memref + %3 = memref.load %arg1[%arg3, %arg4] : memref + %4 = arith.addf %2, %3 : f32 + memref.store %4, %arg2[%arg3, %arg4] : memref } } return @@ -387,7 +387,7 @@ ins(%A, %B: memref, memref) outs(%C: memref) { ^bb0(%a: f32, %b: f32, %c: f32): - %d = addf %a, %b : f32 + %d = arith.addf %a, %b : f32 linalg.yield %d : f32 } return @@ -518,7 +518,7 @@ ``` * `memref.view`, -* `std.subview`, +* `memref.subview`, * `memref.transpose`. 
* `linalg.range`, * `linalg.slice`, diff --git a/mlir/docs/Dialects/MemRef.md b/mlir/docs/Dialects/MemRef.md --- a/mlir/docs/Dialects/MemRef.md +++ b/mlir/docs/Dialects/MemRef.md @@ -16,7 +16,7 @@ Syntax: ``` -operation ::= `dma_start` ssa-use`[`ssa-use-list`]` `,` +operation ::= `memref.dma_start` ssa-use`[`ssa-use-list`]` `,` ssa-use`[`ssa-use-list`]` `,` ssa-use `,` ssa-use`[`ssa-use-list`]` (`,` ssa-use `,` ssa-use)? `:` memref-type `,` memref-type `,` memref-type @@ -39,17 +39,17 @@ destination memref need not be of the same dimensionality, but need to have the same elemental type. -For example, a `dma_start` operation that transfers 32 vector elements from a -memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` would be -specified as shown below. +For example, a `memref.dma_start` operation that transfers 32 vector elements +from a memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` +would be specified as shown below. Example: ```mlir -%size = constant 32 : index -%tag = alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4> -%idx = constant 0 : index -dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] : +%size = arith.constant 32 : index +%tag = memref.alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4> +%idx = arith.constant 0 : index +memref.dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] : memref<40 x 8 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 0>, memref<2 x 4 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 2>, memref<1 x i32>, affine_map<(d0) -> (d0)>, 4> @@ -60,7 +60,7 @@ Syntax: ``` -operation ::= `dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type +operation ::= `memref.dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type ``` Blocks until the completion of a DMA operation associated with the tag element @@ -72,5 +72,5 @@ Example: ```mlir -dma_wait %tag[%idx], %size : memref<1 x i32, affine_map<(d0) -> (d0)>, 4> +memref.dma_wait %tag[%idx], %size : memref<1 x i32, 
affine_map<(d0) -> (d0)>, 4> ``` diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md --- a/mlir/docs/Dialects/Vector.md +++ b/mlir/docs/Dialects/Vector.md @@ -95,8 +95,8 @@ ### Virtual Vector Ops Some existing Standard and Vector Dialect on `n-D` `vector` types comprise: ``` -%2 = std.addf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> -%2 = std.mulf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> +%2 = arith.addf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> +%2 = arith.mulf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> %2 = std.splat %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> %1 = vector.extract %0[1]: vector<3x7x8xf32> // -> vector<7x8xf32> diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -23,13 +23,15 @@ Besides operations part of the EmitC dialect, the Cpp targets supports translating the following operations: -* 'std' Dialect - * `std.br` - * `std.call` - * `std.cond_br` - * `std.constant` - * `std.return` -* 'scf' Dialect - * `scf.for` - * `scf.if` - * `scf.yield` +* 'std' Dialect + * `std.br` + * `std.call` + * `std.cond_br` + * `std.constant` + * `std.return` +* 'scf' Dialect + * `scf.for` + * `scf.if` + * `scf.yield` +* 'arith' Dialect + * `arith.constant` diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -76,26 +76,26 @@ // known. The shapes are assumed to match. func @mul(%A: tensor<100x?xf32>, %B: tensor) -> (tensor<100x50xf32>) { // Compute the inner dimension of %A using the dim operation. - %n = dim %A, 1 : tensor<100x?xf32> + %n = tensor.dim %A, 1 : tensor<100x?xf32> // Allocate addressable "buffers" and copy tensors %A and %B into them.
- %A_m = alloc(%n) : memref<100x?xf32> - tensor_store %A to %A_m : memref<100x?xf32> + %A_m = memref.alloc(%n) : memref<100x?xf32> + memref.tensor_store %A to %A_m : memref<100x?xf32> - %B_m = alloc(%n) : memref - tensor_store %B to %B_m : memref + %B_m = memref.alloc(%n) : memref + memref.tensor_store %B to %B_m : memref // Call function @multiply passing memrefs as arguments, // and getting returned the result of the multiplication. %C_m = call @multiply(%A_m, %B_m) : (memref<100x?xf32>, memref) -> (memref<100x50xf32>) - dealloc %A_m : memref<100x?xf32> - dealloc %B_m : memref + memref.dealloc %A_m : memref<100x?xf32> + memref.dealloc %B_m : memref // Load the buffer data into a higher level "tensor" value. - %C = tensor_load %C_m : memref<100x50xf32> - dealloc %C_m : memref<100x50xf32> + %C = memref.tensor_load %C_m : memref<100x50xf32> + memref.dealloc %C_m : memref<100x50xf32> // Call TensorFlow built-in function to print the result tensor. "tf.Print"(%C){message: "mul result"} @@ -108,22 +108,22 @@ func @multiply(%A: memref<100x?xf32>, %B: memref) -> (memref<100x50xf32>) { // Compute the inner dimension of %A. - %n = dim %A, 1 : memref<100x?xf32> + %n = memref.dim %A, 1 : memref<100x?xf32> // Allocate memory for the multiplication result. - %C = alloc() : memref<100x50xf32> + %C = memref.alloc() : memref<100x50xf32> // Multiplication loop nest. 
affine.for %i = 0 to 100 { affine.for %j = 0 to 50 { - store 0 to %C[%i, %j] : memref<100x50xf32> + memref.store 0 to %C[%i, %j] : memref<100x50xf32> affine.for %k = 0 to %n { - %a_v = load %A[%i, %k] : memref<100x?xf32> - %b_v = load %B[%k, %j] : memref - %prod = mulf %a_v, %b_v : f32 - %c_v = load %C[%i, %j] : memref<100x50xf32> - %sum = addf %c_v, %prod : f32 - store %sum, %C[%i, %j] : memref<100x50xf32> + %a_v = memref.load %A[%i, %k] : memref<100x?xf32> + %b_v = memref.load %B[%k, %j] : memref + %prod = arith.mulf %a_v, %b_v : f32 + %c_v = memref.load %C[%i, %j] : memref<100x50xf32> + %sum = arith.addf %c_v, %prod : f32 + memref.store %sum, %C[%i, %j] : memref<100x50xf32> } } } @@ -389,7 +389,7 @@ br ^bb3(%a: i64) // Branch passes %a as the argument ^bb2: - %b = addi %a, %a : i64 + %b = arith.addi %a, %a : i64 br ^bb3(%b: i64) // Branch passes %b as the argument // ^bb3 receives an argument, named %c, from predecessors @@ -400,7 +400,7 @@ br ^bb4(%c, %a : i64, i64) ^bb4(%d : i64, %e : i64): - %0 = addi %d, %e : i64 + %0 = arith.addi %d, %e : i64 return %0 : i64 // Return is also a terminator. } ``` @@ -756,7 +756,7 @@ - *inherent attributes* are inherent to the definition of an operation's semantics. The operation itself is expected to verify the consistency of these - attributes. An example is the `predicate` attribute of the `std.cmpi` op. + attributes. An example is the `predicate` attribute of the `arith.cmpi` op. These attributes must have names that do not start with a dialect prefix. 
- *discardable attributes* have semantics defined externally to the operation diff --git a/mlir/docs/Rationale/MLIRForGraphAlgorithms.md b/mlir/docs/Rationale/MLIRForGraphAlgorithms.md --- a/mlir/docs/Rationale/MLIRForGraphAlgorithms.md +++ b/mlir/docs/Rationale/MLIRForGraphAlgorithms.md @@ -156,7 +156,7 @@ ```mlir // RUN: mlir-opt %s -canonicalize | FileCheck %s func @test_subi_zero_cfg(%arg0: i32) -> i32 { - %y = subi %arg0, %arg0 : i32 + %y = arith.subi %arg0, %arg0 : i32 return %y: i32 } // CHECK-LABEL: func @test_subi_zero_cfg(%arg0: i32) @@ -210,13 +210,13 @@ ```mlir // RUN: mlir-opt %s -memref-dependence-check -verify-diagnostics func @different_memrefs() { - %m.a = alloc() : memref<100xf32> - %m.b = alloc() : memref<100xf32> - %c0 = constant 0 : index - %c1 = constant 1.0 : f32 - store %c1, %m.a[%c0] : memref<100xf32> + %m.a = memref.alloc() : memref<100xf32> + %m.b = memref.alloc() : memref<100xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1.0 : f32 + memref.store %c1, %m.a[%c0] : memref<100xf32> // expected-note@-1 {{dependence from memref access 0 to access 1 = false}} - %v0 = load %m.b[%c0] : memref<100xf32> + %v0 = memref.load %m.b[%c0] : memref<100xf32> return } ``` diff --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md --- a/mlir/docs/Rationale/Rationale.md +++ b/mlir/docs/Rationale/Rationale.md @@ -136,7 +136,7 @@ ```mlir func foo(...) { - %A = alloc <8x?xf32, #lmap> (%N) + %A = memref.alloc <8x?xf32, #lmap> (%N) ... call bar(%A) : (memref<8x?xf32, #lmap>) } @@ -145,7 +145,7 @@ // Type of %A indicates that %A has dynamic shape with 8 rows // and unknown number of columns. The number of columns is queried // dynamically using dim instruction. - %N = dim %A, 1 : memref<8x?xf32, #lmap> + %N = memref.dim %A, 1 : memref<8x?xf32, #lmap> affine.for %i = 0 to 8 { affine.for %j = 0 to %N { @@ -250,8 +250,9 @@ For the standard dialect, the choice is to have signless integer types. 
An integer value does not have an intrinsic sign, and it's up to the specific op -for interpretation. For example, ops like `addi` and `muli` do two's complement -arithmetic, but some other operations get a sign, e.g. `divis` vs `diviu`. +for interpretation. For example, ops like `arith.addi` and `arith.muli` do +two's complement arithmetic, but some other operations get a sign, e.g. +`arith.divsi` vs `arith.divui`. LLVM uses the [same design](http://llvm.org/docs/LangRef.html#integer-type), which was introduced in a revamp rolled out @@ -279,9 +280,9 @@ ### Splitting floating point vs integer operations -The MLIR "standard" operation set splits many integer and floating point -operations into different categories, for example `addf` vs `addi` and `cmpf` vs -`cmpi` +The MLIR "Arithmetic" dialect splits many integer and floating point operations +into different categories, for example `arith.addf` vs `arith.addi` and +`arith.cmpf` vs `arith.cmpi` ([following the design of LLVM](http://llvm.org/docs/LangRef.html#binary-operations)). 
These instructions _are_ polymorphic on the number of elements in the type though, for example `addf` is used with scalar floats, vectors of floats, and @@ -547,7 +548,7 @@ ```mlir func @search(%A: memref, %S: , %key : i32) { - %ni = dim %A, 0 : memref + %ni = memref.dim %A, 0 : memref // This loop can be parallelized affine.for %i = 0 to %ni { call @search_body (%A, %S, %key, %i) : (memref, memref, i32, i32) @@ -556,16 +557,16 @@ } func @search_body(%A: memref, %S: memref, %key: i32, %i : i32) { - %nj = dim %A, 1 : memref + %nj = memref.dim %A, 1 : memref br ^bb1(0) ^bb1(%j: i32) - %p1 = cmpi "lt", %j, %nj : i32 + %p1 = arith.cmpi "lt", %j, %nj : i32 cond_br %p1, ^bb2, ^bb5 ^bb2: %v = affine.load %A[%i, %j] : memref - %p2 = cmpi "eq", %v, %key : i32 + %p2 = arith.cmpi "eq", %v, %key : i32 cond_br %p2, ^bb3(%j), ^bb4 ^bb3(%j: i32) @@ -573,7 +574,7 @@ br ^bb5 ^bb4: - %jinc = addi %j, 1 : i32 + %jinc = arith.addi %j, 1 : i32 br ^bb1(%jinc) ^bb5: @@ -844,7 +845,7 @@ bb0 (%0, %1: memref<128xf32>, i64): %val = affine.load %A [%pos] %val = affine.load %A [%pos + 1] - %p = mulf %val, %val : f32 + %p = arith.mulf %val, %val : f32 return %p : f32 } ``` diff --git a/mlir/docs/SPIRVToLLVMDialectConversion.md b/mlir/docs/SPIRVToLLVMDialectConversion.md --- a/mlir/docs/SPIRVToLLVMDialectConversion.md +++ b/mlir/docs/SPIRVToLLVMDialectConversion.md @@ -857,7 +857,7 @@ func @main() { // Fill the buffer with some data - %buffer = alloc : memref<8xi32> + %buffer = memref.alloc : memref<8xi32> %data = ... call fillBuffer(%buffer, %data) @@ -880,7 +880,7 @@ func @main() { // Fill the buffer with some data. - %buffer = alloc : memref<8xi32> + %buffer = memref.alloc : memref<8xi32> %data = ... call fillBuffer(%buffer, %data) diff --git a/mlir/docs/SymbolsAndSymbolTables.md b/mlir/docs/SymbolsAndSymbolTables.md --- a/mlir/docs/SymbolsAndSymbolTables.md +++ b/mlir/docs/SymbolsAndSymbolTables.md @@ -137,9 +137,9 @@ different trade offs depending on the situation. 
A function call may directly use a `SymbolRef` as the callee, whereas a reference to a global variable might use a materialization operation so that the variable can be used in other -operations like `std.addi`. -[`llvm.mlir.addressof`](Dialects/LLVM.md/#llvmmliraddressof-mlirllvmaddressofop) is one example of -such an operation. +operations like `arith.addi`. +[`llvm.mlir.addressof`](Dialects/LLVM.md/#llvmmliraddressof-mlirllvmaddressofop) +is one example of such an operation. See the `LangRef` definition of the [`SymbolRefAttr`](Dialects/Builtin.md/#symbolrefattr) for more information diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md --- a/mlir/docs/TargetLLVMIR.md +++ b/mlir/docs/TargetLLVMIR.md @@ -305,8 +305,8 @@ return %arg0, %arg1 : i32, i64 } func @bar() { - %0 = constant 42 : i32 - %1 = constant 17 : i64 + %0 = arith.constant 42 : i32 + %1 = arith.constant 17 : i64 %2:2 = call @foo(%0, %1) : (i32, i64) -> (i32, i64) "use_i32"(%2#0) : (i32) -> () "use_i64"(%2#1) : (i64) -> () @@ -768,7 +768,7 @@ An access to a memref with indices: ```mlir -%0 = load %m[%1,%2,%3,%4] : memref +%0 = memref.load %m[%1,%2,%3,%4] : memref ``` is transformed into the equivalent of the following code: @@ -779,27 +779,27 @@ // dynamic, extract the stride value from the descriptor. %stride1 = llvm.extractvalue[4, 0] : !llvm.struct<(ptr, ptr, i64, array<4xi64>, array<4xi64>)> -%addr1 = muli %stride1, %1 : i64 +%addr1 = arith.muli %stride1, %1 : i64 // When the stride or, in absence of explicit strides, the trailing sizes are // known statically, this value is used as a constant. The natural value of // strides is the product of all sizes following the current dimension. 
%stride2 = llvm.mlir.constant(32 : index) : i64 -%addr2 = muli %stride2, %2 : i64 -%addr3 = addi %addr1, %addr2 : i64 +%addr2 = arith.muli %stride2, %2 : i64 +%addr3 = arith.addi %addr1, %addr2 : i64 %stride3 = llvm.mlir.constant(8 : index) : i64 -%addr4 = muli %stride3, %3 : i64 -%addr5 = addi %addr3, %addr4 : i64 +%addr4 = arith.muli %stride3, %3 : i64 +%addr5 = arith.addi %addr3, %addr4 : i64 // Multiplication with the known unit stride can be omitted. -%addr6 = addi %addr5, %4 : i64 +%addr6 = arith.addi %addr5, %4 : i64 // If the linear offset is known to be zero, it can also be omitted. If it is // dynamic, it is extracted from the descriptor. %offset = llvm.extractvalue[2] : !llvm.struct<(ptr, ptr, i64, array<4xi64>, array<4xi64>)> -%addr7 = addi %addr6, %offset : i64 +%addr7 = arith.addi %addr6, %offset : i64 // All accesses are based on the aligned pointer. %aligned = llvm.extractvalue[1] : !llvm.struct<(ptr, ptr, i64, diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -300,7 +300,7 @@ `IsolatedFromAbove`: ```mlir -%result = constant 10 : i32 +%result = arith.constant 10 : i32 foo.region_op { foo.yield %result : i32 } diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md --- a/mlir/docs/Tutorials/Toy/Ch-5.md +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -15,11 +15,11 @@ `Affine` for the computation heavy part of Toy, and in the [next chapter](Ch-6.md) directly target the `LLVM IR` dialect for lowering `print`. As part of this lowering, we will be lowering from the -[TensorType](../../Dialects/Builtin.md/#rankedtensortype) that `Toy` -operates on to the [MemRefType](../../Dialects/Builtin.md/#memreftype) that is -indexed via an affine loop-nest. Tensors represent an abstract value-typed +[TensorType](../../Dialects/Builtin.md/#rankedtensortype) that `Toy` +operates on to the [MemRefType](../../Dialects/Builtin.md/#memreftype) that is +indexed via an affine loop-nest. 
Tensors represent an abstract value-typed sequence of data, meaning that they don't live in any memory. MemRefs, on the -other hand, represent lower level buffer access, as they are concrete +other hand, represent lower level buffer access, as they are concrete references to a region of memory. # Dialect Conversions @@ -63,9 +63,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want @@ -97,7 +97,7 @@ remapped/replaced. This is used when dealing with type conversions, as the pattern will want to operate on values of the new type but match against the old. For our lowering, this invariant will be useful as it translates from the -[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently +[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being operated on to the [MemRefType](../../Dialects/Builtin.md/#memreftype). Let's look at a snippet of lowering the `toy.transpose` operation: @@ -241,17 +241,17 @@ ```mlir func @main() { - %cst = constant 1.000000e+00 : f64 - %cst_0 = constant 2.000000e+00 : f64 - %cst_1 = constant 3.000000e+00 : f64 - %cst_2 = constant 4.000000e+00 : f64 - %cst_3 = constant 5.000000e+00 : f64 - %cst_4 = constant 6.000000e+00 : f64 + %cst = arith.constant 1.000000e+00 : f64 + %cst_0 = arith.constant 2.000000e+00 : f64 + %cst_1 = arith.constant 3.000000e+00 : f64 + %cst_2 = arith.constant 4.000000e+00 : f64 + %cst_3 = arith.constant 5.000000e+00 : f64 + %cst_4 = arith.constant 6.000000e+00 : f64 // Allocating buffers for the inputs and outputs. 
- %0 = alloc() : memref<3x2xf64> - %1 = alloc() : memref<3x2xf64> - %2 = alloc() : memref<2x3xf64> + %0 = memref.alloc() : memref<3x2xf64> + %1 = memref.alloc() : memref<3x2xf64> + %2 = memref.alloc() : memref<2x3xf64> // Initialize the input buffer with the constant values. affine.store %cst, %2[0, 0] : memref<2x3xf64> @@ -275,16 +275,16 @@ affine.for %arg1 = 0 to 2 { %3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64> %4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64> - %5 = mulf %3, %4 : f64 + %5 = arith.mulf %3, %4 : f64 affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64> } } // Print the value held by the buffer. toy.print %0 : memref<3x2xf64> - dealloc %2 : memref<2x3xf64> - dealloc %1 : memref<3x2xf64> - dealloc %0 : memref<3x2xf64> + memref.dealloc %2 : memref<2x3xf64> + memref.dealloc %1 : memref<3x2xf64> + memref.dealloc %0 : memref<3x2xf64> return } ``` @@ -299,16 +299,16 @@ ```mlir func @main() { - %cst = constant 1.000000e+00 : f64 - %cst_0 = constant 2.000000e+00 : f64 - %cst_1 = constant 3.000000e+00 : f64 - %cst_2 = constant 4.000000e+00 : f64 - %cst_3 = constant 5.000000e+00 : f64 - %cst_4 = constant 6.000000e+00 : f64 + %cst = arith.constant 1.000000e+00 : f64 + %cst_0 = arith.constant 2.000000e+00 : f64 + %cst_1 = arith.constant 3.000000e+00 : f64 + %cst_2 = arith.constant 4.000000e+00 : f64 + %cst_3 = arith.constant 5.000000e+00 : f64 + %cst_4 = arith.constant 6.000000e+00 : f64 // Allocating buffers for the inputs and outputs. - %0 = alloc() : memref<3x2xf64> - %1 = alloc() : memref<2x3xf64> + %0 = memref.alloc() : memref<3x2xf64> + %1 = memref.alloc() : memref<2x3xf64> // Initialize the input buffer with the constant values. affine.store %cst, %1[0, 0] : memref<2x3xf64> @@ -324,15 +324,15 @@ %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64> // Multiply and store into the output buffer. - %3 = mulf %2, %2 : f64 + %3 = arith.mulf %2, %2 : f64 affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64> } } // Print the value held by the buffer. 
toy.print %0 : memref<3x2xf64> - dealloc %1 : memref<2x3xf64> - dealloc %0 : memref<3x2xf64> + memref.dealloc %1 : memref<2x3xf64> + memref.dealloc %0 : memref<3x2xf64> return } ``` diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -84,15 +84,17 @@ Now that the conversion target has been defined, we need to provide the patterns used for lowering. At this point in the compilation process, we have a -combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and -`affine` dialects already provide the set of patterns needed to transform them -into LLVM dialect. These patterns allow for lowering the IR in multiple stages -by relying on [transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering). +combination of `toy`, `affine`, `arith`, and `std` operations. Luckily, the +`affine`, `arith`, and `std` dialects already provide the set of patterns +needed to transform them into LLVM dialect. These patterns allow for lowering +the IR in multiple stages by relying on +[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering). ```c++ mlir::RewritePatternSet patterns(&getContext()); mlir::populateAffineToStdConversionPatterns(patterns, &getContext()); mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); // The only remaining operation, to lower from the `toy` dialect, is the @@ -200,7 +202,7 @@ %106 = mul i64 %100, 1 %107 = add i64 %105, %106 %108 = getelementptr double, double* %103, i64 %107 - %109 = load double, double* %108 + %109 = load double, double* %108 %110 = call i32 (i8*, ...)
@printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109) %111 = add i64 %100, 1 br label %99 @@ -322,7 +324,7 @@ [`--print-ir-after-all`](../../PassManagement.md/#ir-printing) to track the evolution of the IR throughout the pipeline. -The example code used throughout this section can be found in +The example code used throughout this section can be found in test/Examples/Toy/Ch6/llvm-lowering.mlir. So far, we have worked with primitive data types. In the diff --git a/mlir/docs/includes/img/branch_example_post_move.svg b/mlir/docs/includes/img/branch_example_post_move.svg --- a/mlir/docs/includes/img/branch_example_post_move.svg +++ b/mlir/docs/includes/img/branch_example_post_move.svg @@ -414,6 +414,6 @@ id="tspan3407" x="21.911886" y="15.884925" - style="font-size:5.64444px;fill:#008000;stroke-width:0.264583">%0 = alloc() + style="font-size:5.64444px;fill:#008000;stroke-width:0.264583">%0 = memref.alloc() diff --git a/mlir/docs/includes/img/branch_example_pre_move.svg b/mlir/docs/includes/img/branch_example_pre_move.svg --- a/mlir/docs/includes/img/branch_example_pre_move.svg +++ b/mlir/docs/includes/img/branch_example_pre_move.svg @@ -353,7 +353,7 @@ transform="translate(8.4353227,-0.28369449)">%0 = alloc()%0 = memref.alloc() %1 = alloc(%0)%1 = memref.alloc(%0)%5 = alloc(%d0)%5 = memref.alloc(%d0)%6 = alloc(%d1)%6 = memref.alloc(%d1)%1 = alloc(%0)%1 = memref.alloc(%0)(); - registry.insert(); + registry.insert(); // Add the following to include *all* MLIR Core dialects, or selectively // include what you need like above. 
You only need to register dialects that // will be *parsed* by the tool, not the one generated diff --git a/mlir/examples/standalone/test/Standalone/dummy.mlir b/mlir/examples/standalone/test/Standalone/dummy.mlir --- a/mlir/examples/standalone/test/Standalone/dummy.mlir +++ b/mlir/examples/standalone/test/Standalone/dummy.mlir @@ -3,7 +3,7 @@ module { // CHECK-LABEL: func @bar() func @bar() { - %0 = constant 1 : i32 + %0 = arith.constant 1 : i32 // CHECK: %{{.*}} = standalone.foo %{{.*}} : i32 %res = standalone.foo %0 : i32 return diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + constantIndices.push_back( + rewriter.create(loc, i)); } else { // This is the case of a tensor of rank 0. - constantIndices.push_back(rewriter.create(loc, 0)); + constantIndices.push_back( + rewriter.create(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -171,7 +174,7 @@ // we store the element at the given index. 
if (dimension == valueShape.size()) { rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, + loc, rewriter.create(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -284,9 +287,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + constantIndices.push_back( + rewriter.create(loc, i)); } else { // This is the case of a tensor of rank 0. 
- constantIndices.push_back(rewriter.create(loc, 0)); + constantIndices.push_back( + rewriter.create(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we // will need to generate a store for each of the elements. The following @@ -170,7 +173,7 @@ // we store the element at the given index. if (dimension == valueShape.size()) { rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, + loc, rewriter.create(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -283,9 +286,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -25,6 +25,7 @@ #include "toy/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -32,6 +33,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -73,9 +75,10 @@ // Create a loop for each of the dimensions within the shape. 
SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = + rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); auto loop = rewriter.create(loc, lowerBound, upperBound, step); for (Operation &nested : *loop.getBody()) @@ -198,6 +201,8 @@ RewritePatternSet patterns(&getContext()); populateAffineToStdConversionPatterns(patterns); populateLoopToStdConversionPatterns(patterns); + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + constantIndices.push_back( + rewriter.create(loc, i)); } else { // This is the case of a tensor of rank 0. 
- constantIndices.push_back(rewriter.create(loc, 0)); + constantIndices.push_back( + rewriter.create(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -171,7 +174,7 @@ // we store the element at the given index. if (dimension == valueShape.size()) { rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, + loc, rewriter.create(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -284,9 +287,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -25,6 +25,7 @@ #include "toy/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -32,6 +33,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -73,9 +75,10 @@ // Create a loop for each of the dimensions within the shape. 
SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = + rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); auto loop = rewriter.create(loc, lowerBound, upperBound, step); for (Operation &nested : *loop.getBody()) @@ -198,6 +201,8 @@ RewritePatternSet patterns(&getContext()); populateAffineToStdConversionPatterns(patterns); populateLoopToStdConversionPatterns(patterns); + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); diff --git a/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h b/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h @@ -0,0 +1,28 @@ +//===- ArithmeticToLLVM.h - Arith to LLVM dialect conversion ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H +#define MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H + +#include + +namespace mlir { + +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; + +namespace arith { +void populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +std::unique_ptr createConvertArithmeticToLLVMPass(); +} // end namespace arith +} // end namespace mlir + +#endif // MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H diff --git a/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h @@ -0,0 +1,28 @@ +//===- ArithmeticToSPIRV.h - Convert Arith to SPIRV dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H +#define MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H + +#include + +namespace mlir { + +class SPIRVTypeConverter; +class RewritePatternSet; +class Pass; + +namespace arith { +void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); + +std::unique_ptr createConvertArithmeticToSPIRVPass(); +} // end namespace arith +} // end namespace mlir + +#endif // MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -10,6 +10,8 @@ #define MLIR_CONVERSION_PASSES_H #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -39,10 +39,10 @@ %d0 = <...> %d1 = <...> %s0 = <...> - %0 = constant 2 : index - %1 = muli %0, %d1 - %2 = addi %d0, %1 - %r = addi %2, %s0 + %0 = arith.constant 2 : index + %1 = arith.muli %0, %d1 + %2 = arith.addi %d0, %1 + %r = arith.addi %2, %s0 ``` #### Input invariant @@ -74,6 +74,40 @@ ]; } +//===----------------------------------------------------------------------===// +// ArithmeticToLLVM +//===----------------------------------------------------------------------===// + +def ConvertArithmeticToLLVM : 
FunctionPass<"convert-arith-to-llvm"> { + let summary = "Convert Arithmetic dialect to LLVM dialect"; + let description = [{ + This pass converts supported Arithmetic ops to LLVM dialect instructions. + }]; + let constructor = "mlir::arith::createConvertArithmeticToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; + let options = [ + Option<"indexBitwidth", "index-bitwidth", "unsigned", + /*default=kDeriveIndexBitwidthFromDataLayout*/"0", + "Bitwidth of the index type, 0 to use size of machine word">, + ]; +} + +//===----------------------------------------------------------------------===// +// ArithmeticToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertArithmeticToSPIRV : FunctionPass<"convert-arith-to-spirv"> { + let summary = "Convert Arithmetic dialect to SPIR-V dialect"; + let constructor = "mlir::arith::createConvertArithmeticToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; + let options = [ + Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types", + "bool", /*default=*/"true", + "Emulate non-32-bit scalar types with 32-bit ones if " + "missing native support"> + ]; +} + //===----------------------------------------------------------------------===// // AsyncToLLVM //===----------------------------------------------------------------------===// @@ -86,7 +120,10 @@ API to execute them. 
}]; let constructor = "mlir::createConvertAsyncToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "LLVM::LLVMDialect", + ]; } //===----------------------------------------------------------------------===// @@ -106,11 +143,7 @@ def ConvertComplexToStandard : FunctionPass<"convert-complex-to-standard"> { let summary = "Convert Complex dialect to standard dialect"; let constructor = "mlir::createConvertComplexToStandardPass()"; - let dependentDialects = [ - "complex::ComplexDialect", - "math::MathDialect", - "StandardOpsDialect" - ]; + let dependentDialects = ["math::MathDialect"]; } //===----------------------------------------------------------------------===// @@ -136,7 +169,11 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { let summary = "Generate NVVM operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()"; - let dependentDialects = ["NVVM::NVVMDialect", "memref::MemRefDialect"]; + let dependentDialects = [ + "memref::MemRefDialect", + "NVVM::NVVMDialect", + "StandardOpsDialect", + ]; let options = [ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", @@ -252,7 +289,11 @@ This pass converts supported Math ops to libm calls. 
}]; let constructor = "mlir::createConvertMathToLibmPass()"; - let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "StandardOpsDialect", + "vector::VectorDialect", + ]; } //===----------------------------------------------------------------------===// @@ -448,7 +489,6 @@ let dependentDialects = [ "StandardOpsDialect", "scf::SCFDialect", - "tensor::TensorDialect" ]; } @@ -583,7 +623,11 @@ def TosaToStandard : Pass<"tosa-to-standard"> { let summary = "Lower TOSA to the Standard dialect"; - let dependentDialects = ["StandardOpsDialect", "tensor::TensorDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "StandardOpsDialect", + "tensor::TensorDialect", + ]; let description = [{ Pass that converts TOSA operations to the equivalent operations using the operations in the Standard dialect. diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h --- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h +++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h @@ -37,7 +37,7 @@ /// affine.for %I = 0 to 9 { /// %dim = dim %A, 0 : memref /// %add = affine.apply %I + %a -/// %cmp = cmpi "slt", %add, %dim : index +/// %cmp = arith.cmpi "slt", %add, %dim : index /// scf.if %cmp { /// %vec_2d = load %1[%I] : memref<9xvector<17x15xf32>> /// vector.transfer_write %vec_2d, %A[%add, %b, %c] : diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -23,6 +23,7 @@ let name = "affine"; let cppNamespace = "mlir"; let hasConstantMaterializer = 1; + let dependentDialects = ["arith::ArithmeticDialect"]; } // Base class for Affine dialect ops. 
@@ -201,7 +202,7 @@ %sum = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32) { %t = affine.load %buffer[%i] : memref<1024xf32> - %sum_next = addf %sum_iter, %t : f32 + %sum_next = arith.addf %sum_iter, %t : f32 // Yield current iteration sum to next iteration %sum_iter or to %sum // if final iteration. affine.yield %sum_next : f32 @@ -213,8 +214,8 @@ ```mlir %res:2 = affine.for %i = 0 to 128 iter_args(%arg0 = %init0, %arg1 = %init1) -> (index, index) { - %y0 = addi %arg0, %c1 : index - %y1 = addi %arg1, %c2 : index + %y0 = arith.addi %arg0, %c1 : index + %y1 = arith.addi %arg1, %c2 : index affine.yield %y0, %y1 : index, index } ``` @@ -656,7 +657,7 @@ %0 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addf") { %1 = affine.load %D[%x + %kx, %y + %ky] : memref<100x100xf32> %2 = affine.load %K[%kx, %ky] : memref<3x3xf32> - %3 = mulf %1, %2 : f32 + %3 = arith.mulf %1, %2 : f32 affine.yield %3 : f32 } affine.store %0, O[%x, %y] : memref<98x98xf32> diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -112,7 +112,7 @@ affine.for %i1 = 0 to 10 { affine.store %cf7, %m[%i0, %i1] : memref<10x10xf32> %v0 = affine.load %m[%i0, %i1] : memref<10x10xf32> - %v1 = addf %v0, %v0 : f32 + %v1 = arith.addf %v0, %v0 : f32 } } return %m : memref<10x10xf32> @@ -129,7 +129,7 @@ affine.for %arg0 = 0 to 10 { affine.for %arg1 = 0 to 10 { affine.store %cst, %0[%arg0, %arg1] : memref<10x10xf32> - %1 = addf %cst, %cst : f32 + %1 = arith.addf %cst, %cst : f32 } } return %0 : memref<10x10xf32> diff --git a/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Arithmetic/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git 
a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -10,6 +10,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" @@ -33,6 +34,64 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.h.inc" +namespace mlir { +namespace arith { + +/// Specialization of `arith.constant` op that returns an integer value. +class ConstantIntOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant int op that produces an integer of the specified width. + static void build(OpBuilder &builder, OperationState &result, int64_t value, + unsigned width); + + /// Build a constant int op that produces an integer of the specified type, + /// which must be an integer type. + static void build(OpBuilder &builder, OperationState &result, int64_t value, + Type type); + + inline int64_t value() { + return arith::ConstantOp::value().cast().getInt(); + } + + static bool classof(Operation *op); +}; + +/// Specialization of `arith.constant` op that returns a floating point value. +class ConstantFloatOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant float op that produces a float of the specified type. + static void build(OpBuilder &builder, OperationState &result, + const APFloat &value, FloatType type); + + inline APFloat value() { + return arith::ConstantOp::value().cast().getValue(); + } + + static bool classof(Operation *op); +}; + +/// Specialization of `arith.constant` op that returns an integer of index type. 
+class ConstantIndexOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant int op that produces an index. + static void build(OpBuilder &builder, OperationState &result, int64_t value); + + inline int64_t value() { + return arith::ConstantOp::value().cast().getInt(); + } + + static bool classof(Operation *op); +}; + +} // end namespace arith +} // end namespace mlir + //===----------------------------------------------------------------------===// // Utility Functions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td @@ -20,6 +20,8 @@ ops, bitwise and shift ops, cast ops, and compare ops. Operations in this dialect also accept vectors and tensors of integers or floats. }]; + + let hasConstantMaterializer = 1; } // The predicate indicates the type of the comparison to perform: diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -13,6 +13,7 @@ include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" +include "mlir/IR/OpAsmInterface.td" // Base class for Arithmetic dialect ops. Ops in this dialect have no side // effects and can be applied element-wise to vectors and tensors. 
@@ -119,12 +120,14 @@ //===----------------------------------------------------------------------===// def Arith_ConstantOp : Op, + TypesMatchWith< + "result and attribute have the same type", "value", "result", "$_self">]> { let summary = "integer or floating point constant"; let description = [{ - The `const` operation produces an SSA value equal to some integer or + The `constant` operation produces an SSA value equal to some integer or floating-point constant specified by an attribute. This is the way MLIR forms simple integer and floating point constants. @@ -149,6 +152,12 @@ [{ build($_builder, $_state, type, value); }]>, ]; + let extraClassDeclaration = [{ + /// Whether the constant op can be constructed with a particular value and + /// type. + static bool isBuildableWith(Attribute value, Type type); + }]; + let hasFolder = 1; let assemblyFormat = "attr-dict $value"; } @@ -351,13 +360,13 @@ ```mlir // Scalar signed integer division remainder. - %a = remsi %b, %c : i64 + %a = arith.remsi %b, %c : i64 // SIMD vector element-wise division remainder. - %f = remsi %g, %h : vector<4xi32> + %f = arith.remsi %g, %h : vector<4xi32> // Tensor element-wise integer division remainder. - %x = remsi %y, %z : tensor<4x?xi8> + %x = arith.remsi %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; @@ -717,10 +726,10 @@ ```mlir %1 = arith.constant 21 : i5 // %1 is 0b10101 - %2 = trunci %1 : i5 to i4 // %2 is 0b0101 - %3 = trunci %1 : i5 to i3 // %3 is 0b101 + %2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101 + %3 = arith.trunci %1 : i5 to i3 // %3 is 0b101 - %5 = trunci %0 : vector<2 x i32> to vector<2 x i16> + %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16> ``` }]; @@ -803,7 +812,14 @@ // IndexCastOp //===----------------------------------------------------------------------===// -def Arith_IndexCastOp : Arith_IToICastOp<"index_cast"> { +// Index cast can convert between memrefs of signless integers and indices too. 
+def IndexCastTypeConstraint : TypeConstraint.predicate]>, + "signless-integer-like or memref of signless-integer">; + +def Arith_IndexCastOp : Arith_CastOp<"index_cast", IndexCastTypeConstraint, + IndexCastTypeConstraint> { let summary = "cast between index and integer types"; let description = [{ Casts between scalar or vector integers and corresponding 'index' scalar or @@ -820,8 +836,15 @@ // BitcastOp //===----------------------------------------------------------------------===// -def Arith_BitcastOp : Arith_CastOp<"bitcast", SignlessIntegerOrFloatLike, - SignlessIntegerOrFloatLike> { +// Bitcast can convert between memrefs of signless integers, indices, and +// floats too. +def BitcastTypeConstraint : TypeConstraint.predicate]>, + "signless-integer-or-float-like or memref of signless-integer or float">; + +def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint, + BitcastTypeConstraint> { let summary = "bitcast between values of equal bit width"; let description = [{ Bitcast an integer or floating point value to an integer or floating point @@ -927,10 +950,10 @@ let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); + static arith::CmpIPredicate getPredicateByName(StringRef name); - CmpIPredicate getPredicate() { - return (CmpIPredicate) (*this)->getAttrOfType( + arith::CmpIPredicate getPredicate() { + return (arith::CmpIPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -983,10 +1006,10 @@ let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); + static arith::CmpFPredicate getPredicateByName(StringRef name); - CmpFPredicate getPredicate() { - return (CmpFPredicate) (*this)->getAttrOfType( + arith::CmpFPredicate getPredicate() { + return (arith::CmpFPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; 
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Arithmetic) +add_public_tablegen_target(MLIRArithmeticTransformsIncGen) + +add_mlir_doc(Passes ArithmeticPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h @@ -0,0 +1,42 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Bufferize.h" + +namespace mlir { +namespace arith { + +/// Add patterns to bufferize Arithmetic ops. +void populateArithmeticBufferizePatterns(BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns); + +/// Create a pass to bufferize Arithmetic ops. +std::unique_ptr createArithmeticBufferizePass(); + +/// Add patterns to expand Arithmetic ops for LLVM lowering. +void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns); + +/// Create a pass to legalize Arithmetic ops for LLVM lowering. 
+std::unique_ptr createArithmeticExpandOpsPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc" + +} // end namespace arith +} // end namespace mlir + +#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td @@ -0,0 +1,26 @@ +//===-- Passes.td - Arithmetic pass definition file --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES +#define MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def ArithmeticBufferize : FunctionPass<"arith-bufferize"> { + let summary = "Bufferize Arithmetic dialect ops."; + let constructor = "mlir::arith::createArithmeticBufferizePass()"; + let dependentDialects = ["memref::MemRefDialect"]; +} + +def ArithmeticExpandOps : FunctionPass<"arith-expand"> { + let summary = "Legalize Arithmetic ops to be convertible to LLVM."; + let constructor = "mlir::arith::createArithmeticExpandOpsPass()"; + let dependentDialects = ["StandardOpsDialect"]; +} + +#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -15,7 
+15,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" -include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td" //===----------------------------------------------------------------------===// @@ -460,24 +460,24 @@ ``` }]; let arguments = (ins - CmpFPredicateAttr:$predicate, + Arith_CmpFPredicateAttr:$predicate, ScalableVectorOf<[AnyFloat]>:$lhs, ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar ); let results = (outs ScalableVectorOf<[I1]>:$result); let builders = [ - OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, + OpBuilder<(ins "arith::CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs); }]>]; let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); + static arith::CmpFPredicate getPredicateByName(StringRef name); - CmpFPredicate getPredicate() { - return (CmpFPredicate)(*this)->getAttrOfType( + arith::CmpFPredicate getPredicate() { + return (arith::CmpFPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -520,24 +520,24 @@ }]; let arguments = (ins - CmpIPredicateAttr:$predicate, + Arith_CmpIPredicateAttr:$predicate, ScalableVectorOf<[I8, I16, I32, I64]>:$lhs, ScalableVectorOf<[I8, I16, I32, I64]>:$rhs ); let results = (outs ScalableVectorOf<[I1]>:$result); let builders = [ - OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, + OpBuilder<(ins "arith::CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs); }]>]; let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); + static arith::CmpIPredicate 
getPredicateByName(StringRef name); - CmpIPredicate getPredicate() { - return (CmpIPredicate)(*this)->getAttrOfType( + arith::CmpIPredicate getPredicate() { + return (arith::CmpIPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -32,7 +32,11 @@ "The minimum task size for sharding parallel operation."> ]; - let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "async::AsyncDialect", + "scf::SCFDialect" + ]; } def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> { diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h --- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h +++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h @@ -9,6 +9,8 @@ #ifndef MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td @@ -18,6 +18,9 @@ The complex dialect is intended to hold complex numbers creation and arithmetic ops. }]; + + let dependentDialects = ["arith::ArithmeticDialect", "StandardOpsDialect"]; + let hasConstantMaterializer = 1; } #endif // COMPLEX_BASE diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td --- a/mlir/include/mlir/Dialect/GPU/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td @@ -51,6 +51,8 @@ /// space. 
static unsigned getPrivateAddressSpace() { return 5; } }]; + + let dependentDialects = ["arith::ArithmeticDialect"]; } def GPU_AsyncToken : DialectType< diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_GPU_GPUDIALECT_H #define MLIR_DIALECT_GPU_GPUDIALECT_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -627,7 +627,7 @@ %1 = "gpu.all_reduce"(%0) ({}) { op = "add" } : (f32) -> (f32) %2 = "gpu.all_reduce"(%0) ({ ^bb(%lhs : f32, %rhs : f32): - %sum = addf %lhs, %rhs : f32 + %sum = arith.addf %lhs, %rhs : f32 "gpu.yield"(%sum) : (f32) -> () }) : (f32) -> (f32) ``` diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -33,11 +33,16 @@ }]; let cppNamespace = "::mlir::linalg"; let dependentDialects = [ - "AffineDialect", "math::MathDialect", "memref::MemRefDialect", - "StandardOpsDialect", "tensor::TensorDialect" + "arith::ArithmeticDialect", + "AffineDialect", + "math::MathDialect", + "memref::MemRefDialect", + "StandardOpsDialect", + "tensor::TensorDialect", ]; let hasCanonicalizer = 1; let hasOperationAttrVerify = 1; + let hasConstantMaterializer = 1; let extraClassDeclaration = [{ /// Attribute name used to to memoize indexing maps for named ops. 
constexpr const static ::llvm::StringLiteral diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -478,8 +478,8 @@ outs(%C : memref) {other-optional-attributes} { ^bb0(%a: f32, %b: f32, %c: f32) : - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 linalg.yield %e : f32 } ``` @@ -501,8 +501,8 @@ %a = load %A[%m, %k] : memref %b = load %B[%k, %n] : memref %c = load %C[%m, %n] : memref - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 store %e, %C[%m, %n] : memref } } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_LINALG_LINALGTYPES_H_ #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -143,7 +143,7 @@ let dependentDialects = [ "linalg::LinalgDialect", "AffineDialect", - "memref::MemRefDialect" + "memref::MemRefDialect", ]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -271,7 +271,7 @@ /// to /// /// %iv = %lb + %procId * %step - /// %cond = cmpi "slt", %iv, %ub + /// %cond = arith.cmpi 
"slt", %iv, %ub /// scf.if %cond { /// ... /// } diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_ #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td @@ -19,7 +19,7 @@ manipulation ops, which are not strongly associated with any particular other dialect or domain abstraction. }]; - let dependentDialects = ["tensor::TensorDialect"]; + let dependentDialects = ["arith::ArithmeticDialect", "tensor::TensorDialect"]; let hasConstantMaterializer = 1; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -156,7 +156,7 @@ omp.wsloop (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) { %a = load %arrA[%i1, %i2] : memref %b = load %arrB[%i1, %i2] : memref - %sum = addf %a, %b : f32 + %sum = arith.addf %a, %b : f32 store %sum, %arrC[%i1, %i2] : memref omp.yield } diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -94,18 +94,18 @@ ```mlir # Before: scf.for %i = %c0 to %arg1 step %c1 { - %0 = addi %arg2, %arg2 : i32 + %0 = arith.addi %arg2, %arg2 : i32 memref.store %0, %arg0[%i] : memref } # After: %0 = scf.while (%i = %c0) : (index) -> index { - %1 = cmpi slt, %i, %arg1 : index 
+ %1 = arith.cmpi slt, %i, %arg1 : index scf.condition(%1) %i : index } do { ^bb0(%i: index): // no predecessors - %1 = addi %i, %c1 : index - %2 = addi %arg2, %arg2 : i32 + %1 = arith.addi %i, %c1 : index + %2 = arith.addi %arg2, %arg2 : i32 memref.store %2, %arg0[%i] : memref scf.yield %1 : index } diff --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h --- a/mlir/include/mlir/Dialect/SCF/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/SCF.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SCF_H_ #define MLIR_DIALECT_SCF_H_ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -86,9 +87,9 @@ /// expect the body building functions to return their current value. /// The built nested scf::For are captured in `capturedLoops` when non-null. LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, - ValueRange ubs, ValueRange steps, - function_ref - bodyBuilder = nullptr); + ValueRange ubs, ValueRange steps, + function_ref + bodyBuilder = nullptr); } // end namespace scf } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -20,6 +20,7 @@ def SCF_Dialect : Dialect { let name = "scf"; let cppNamespace = "::mlir::scf"; + let dependentDialects = ["arith::ArithmeticDialect"]; } // Base class for SCF dialect ops. @@ -170,7 +171,7 @@ %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) { %t = load %buffer[%iv] : memref<1024xf32> - %sum_next = addf %sum_iter, %t : f32 + %sum_next = arith.addf %sum_iter, %t : f32 // Yield current iteration sum to next iteration %sum_iter or to %sum // if final iteration. 
scf.yield %sum_next : f32 @@ -194,9 +195,9 @@ %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) { %t = load %buffer[%iv] : memref<1024xf32> - %cond = cmpf "ugt", %t, %c0 : f32 + %cond = arith.cmpf "ugt", %t, %c0 : f32 %sum_next = scf.if %cond -> (f32) { - %new_sum = addf %sum_iter, %t : f32 + %new_sum = arith.addf %sum_iter, %t : f32 scf.yield %new_sum : f32 } else { scf.yield %sum_iter : f32 @@ -451,7 +452,7 @@ %elem_to_reduce = load %buffer[%iv] : memref<100xf32> scf.reduce(%elem_to_reduce) : f32 { ^bb0(%lhs : f32, %rhs: f32): - %res = addf %lhs, %rhs : f32 + %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 } } @@ -519,7 +520,7 @@ %operand = constant 1.0 : f32 scf.reduce(%operand) : f32 { ^bb0(%lhs : f32, %rhs: f32): - %res = addf %lhs, %rhs : f32 + %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 } ``` diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -14,6 +14,7 @@ #ifndef MLIR_SHAPE_IR_SHAPE_H #define MLIR_SHAPE_IR_SHAPE_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -35,7 +35,7 @@ }]; let cppNamespace = "::mlir::shape"; - let dependentDialects = ["tensor::TensorDialect"]; + let dependentDialects = ["arith::ArithmeticDialect", "tensor::TensorDialect"]; let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ 
b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -43,8 +43,8 @@ ins(%arga, %argb: tensor, tensor) outs(%argx: tensor) { ^bb(%a: f64, %b: f64, %x: f64): - %0 = mulf %a, %b : f64 - %1 = addf %x, %0 : f64 + %0 = arith.mulf %a, %b : f64 + %1 = arith.addf %x, %0 : f64 linalg.yield %1 : f64 } -> tensor return %0 : tensor @@ -54,6 +54,7 @@ let constructor = "mlir::createSparsificationPass()"; let dependentDialects = [ "AffineDialect", + "arith::ArithmeticDialect", "LLVM::LLVMDialect", "memref::MemRefDialect", "scf::SCFDialect", @@ -103,6 +104,7 @@ }]; let constructor = "mlir::createSparseTensorConversionPass()"; let dependentDialects = [ + "arith::ArithmeticDialect", "LLVM::LLVMDialect", "memref::MemRefDialect", "scf::SCFDialect", diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -41,77 +42,15 @@ #include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc" namespace mlir { -/// This is a refinement of the "constant" op for the case where it is -/// returning a float value of FloatType. -/// -/// %1 = "std.constant"(){value: 42.0} : bf16 -/// -class ConstantFloatOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - - /// Builds a constant float op producing a float of the specified type. - static void build(OpBuilder &builder, OperationState &result, - const APFloat &value, FloatType type); - - APFloat getValue() { - return (*this)->getAttrOfType("value").getValue(); - } - - static bool classof(Operation *op); -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of IntegerType. 
-/// -/// %1 = "std.constant"(){value: 42} : i32 -/// -class ConstantIntOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - /// Build a constant int op producing an integer of the specified width. - static void build(OpBuilder &builder, OperationState &result, int64_t value, - unsigned width); - - /// Build a constant int op producing an integer with the specified type, - /// which must be an integer type. - static void build(OpBuilder &builder, OperationState &result, int64_t value, - Type type); - - int64_t getValue() { - return (*this)->getAttrOfType("value").getInt(); - } - - static bool classof(Operation *op); -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of Index type. -/// -/// %1 = "std.constant"(){value: 99} : () -> index -/// -class ConstantIndexOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - - /// Build a constant int op producing an index. - static void build(OpBuilder &builder, OperationState &result, int64_t value); - - int64_t getValue() { - return (*this)->getAttrOfType("value").getInt(); - } - - static bool classof(Operation *op); -}; /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer /// comparison predicates. -bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, +bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs); /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point /// comparison predicates. -bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, +bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs); /// Returns the identity value attribute associated with an AtomicRMWKind op. 
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -25,6 +25,7 @@ def StandardOps_Dialect : Dialect { let name = "std"; let cppNamespace = "::mlir"; + let dependentDialects = ["arith::ArithmeticDialect"]; let hasConstantMaterializer = 1; } @@ -182,138 +183,6 @@ [DeclareOpInterfaceMethods])>, Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>; -//===----------------------------------------------------------------------===// -// AbsFOp -//===----------------------------------------------------------------------===// - -def AbsFOp : FloatUnaryOp<"absf"> { - let summary = "floating point absolute-value operation"; - let description = [{ - The `absf` operation computes the absolute value. It takes one operand and - returns one result of the same type. This type may be a float scalar type, - a vector whose element type is float, or a tensor of floats. - - Example: - - ```mlir - // Scalar absolute value. - %a = absf %b : f64 - - // SIMD vector element-wise absolute value. - %f = absf %g : vector<4xf32> - - // Tensor element-wise absolute value. - %x = absf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// AddFOp -//===----------------------------------------------------------------------===// - -def AddFOp : FloatBinaryOp<"addf"> { - let summary = "floating point addition operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.addf` ssa-use `,` ssa-use `:` type - ``` - - The `addf` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. - - Example: - - ```mlir - // Scalar addition. 
- %a = addf %b, %c : f64 - - // SIMD vector addition, e.g. for Intel SSE. - %f = addf %g, %h : vector<4xf32> - - // Tensor addition. - %x = addf %y, %z : tensor<4x?xbf16> - ``` - - TODO: In the distant future, this will accept optional attributes for fast - math, contraction, rounding mode, and other controls. - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// AddIOp -//===----------------------------------------------------------------------===// - -def AddIOp : IntBinaryOp<"addi", [Commutative]> { - let summary = "integer addition operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.addi` ssa-use `,` ssa-use `:` type - ``` - - The `addi` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be an integer scalar - type, a vector whose element type is integer, or a tensor of integers. It - has no standard attributes. - - Example: - - ```mlir - // Scalar addition. - %a = addi %b, %c : i64 - - // SIMD vector element-wise addition, e.g. for Intel SSE. - %f = addi %g, %h : vector<4xi32> - - // Tensor element-wise addition. - %x = addi %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// AndOp -//===----------------------------------------------------------------------===// - -def AndOp : IntBinaryOp<"and", [Commutative]> { - let summary = "integer binary and"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.and` ssa-use `,` ssa-use `:` type - ``` - - The `and` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise and. 
- %a = and %b, %c : i64 - - // SIMD vector element-wise bitwise integer and. - %f = and %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer and. - %x = and %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // AssertOp //===----------------------------------------------------------------------===// @@ -413,7 +282,7 @@ %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%current_value : f32): %c1 = constant 1.0 : f32 - %inc = addf %c1, %current_value : f32 + %inc = arith.addf %c1, %current_value : f32 atomic_yield %inc : f32 } ``` @@ -456,32 +325,6 @@ let assemblyFormat = "$result attr-dict `:` type($result)"; } -//===----------------------------------------------------------------------===// -// BitcastOp -//===----------------------------------------------------------------------===// - -def BitcastOp : ArithmeticCastOp<"bitcast"> { - let summary = "bitcast between values of equal bit width"; - let description = [{ - Bitcast an integer or floating point value to an integer or floating point - value of equal bit width. When operating on vectors, casts elementwise. - - Note that this implements a logical bitcast independent of target - endianness. This allows constant folding without target information and is - consitent with the bitcast constant folders in LLVM (see - https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168) - For targets where the source and target type have the same endianness (which - is the standard), this cast will also change no bits at runtime, but it may - still require an operation, for example if the machine has different - floating point and integer register files. For targets that have a different - endianness for the source and target types (e.g. float is big-endian and - integer is little-endian) a proper lowering would add operations to swap the - order of words in addition to the bitcast. 
- }]; - let hasFolder = 1; -} - - //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// @@ -666,240 +509,6 @@ let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)"; } -//===----------------------------------------------------------------------===// -// CeilFOp -//===----------------------------------------------------------------------===// - -def CeilFOp : FloatUnaryOp<"ceilf"> { - let summary = "ceiling of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.ceilf` ssa-use `:` type - ``` - - The `ceilf` operation computes the ceiling of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar ceiling value. - %a = ceilf %b : f64 - - // SIMD vector element-wise ceiling value. - %f = ceilf %g : vector<4xf32> - - // Tensor element-wise ceiling value. - %x = ceilf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// FloorFOp -//===----------------------------------------------------------------------===// - -def FloorFOp : FloatUnaryOp<"floorf"> { - let summary = "floor of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.floorf` ssa-use `:` type - ``` - - The `floorf` operation computes the floor of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar floor value. - %a = floorf %b : f64 - - // SIMD vector element-wise floor value. 
- %f = floorf %g : vector<4xf32> - - // Tensor element-wise floor value. - %x = floorf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, - DeclareOpInterfaceMethods, TypesMatchWith< - "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> { - let summary = "floating-point comparison operation"; - let description = [{ - The `cmpf` operation compares its two operands according to the float - comparison rules and the predicate specified by the respective attribute. - The predicate defines the type of comparison: (un)orderedness, (in)equality - and signed less/greater than (or equal to) as well as predicates that are - always true or false. The operands must have the same type, and this type - must be a float type, or a vector or tensor thereof. The result is an i1, - or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, - the operands are always treated as signed. The u prefix indicates - *unordered* comparison, not unsigned comparison, so "une" means unordered or - not equal. For the sake of readability by humans, custom assembly form for - the operation uses a string-typed attribute for the predicate. The value of - this attribute corresponds to lower-cased name of the predicate constant, - e.g., "one" means "ordered not equal". The string representation of the - attribute is merely a syntactic sugar and is converted to an integer - attribute by the parser. 
- - Example: - - ```mlir - %r1 = cmpf "oeq" %0, %1 : f32 - %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> - %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 - ``` - }]; - - let arguments = (ins - CmpFPredicateAttr:$predicate, - FloatLike:$lhs, - FloatLike:$rhs - ); - let results = (outs BoolLike:$result); - - let builders = [ - OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - ::buildCmpFOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); - - CmpFPredicate getPredicate() { - return (CmpFPredicate)(*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let hasFolder = 1; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - -//===----------------------------------------------------------------------===// -// CmpIOp -//===----------------------------------------------------------------------===// - -def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, - DeclareOpInterfaceMethods, TypesMatchWith< - "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> { - let summary = "integer comparison operation"; - let description = [{ - The `cmpi` operation is a generic comparison for integer-like types. Its two - arguments can be integers, vectors or tensors thereof as long as their types - match. The operation produces an i1 for the former case, a vector or a - tensor of i1 with the same shape as inputs in the other cases. - - Its first argument is an attribute that defines which type of comparison is - performed. 
The following comparisons are supported: - - - equal (mnemonic: `"eq"`; integer value: `0`) - - not equal (mnemonic: `"ne"`; integer value: `1`) - - signed less than (mnemonic: `"slt"`; integer value: `2`) - - signed less than or equal (mnemonic: `"sle"`; integer value: `3`) - - signed greater than (mnemonic: `"sgt"`; integer value: `4`) - - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`) - - unsigned less than (mnemonic: `"ult"`; integer value: `6`) - - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`) - - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`) - - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`) - - The result is `1` if the comparison is true and `0` otherwise. For vector or - tensor operands, the comparison is performed elementwise and the element of - the result indicates whether the comparison is true for the operand elements - with the same indices as those of the result. - - Note: while the custom assembly form uses strings, the actual underlying - attribute has integer type (or rather enum class in C++ code) as seen from - the generic assembly form. String literals are used to improve readability - of the IR by humans. - - This operation only applies to integer-like operands, but not floats. The - main reason being that comparison operations have diverging sets of - attributes: integers require sign specification while floats require various - floating point-related particularities, e.g., `-ffast-math` behavior, - IEEE754 compliance, etc - ([rationale](../Rationale/Rationale.md#splitting-floating-point-vs-integer-operations)). - The type of comparison is specified as attribute to avoid introducing ten - similar operations, taking into account that they are often implemented - using the same operation downstream - ([rationale](../Rationale/Rationale.md#specifying-comparison-kind-as-attribute)). 
The - separation between signed and unsigned order comparisons is necessary - because of integers being signless. The comparison operation must know how - to interpret values with the foremost bit being set: negatives in two's - complement or large positives - ([rationale](../Rationale/Rationale.md#specifying-sign-in-integer-comparison-operations)). - - Example: - - ```mlir - // Custom form of scalar "signed less than" comparison. - %x = cmpi "slt", %lhs, %rhs : i32 - - // Generic form of the same operation. - %x = "std.cmpi"(%lhs, %rhs) {predicate = 2 : i64} : (i32, i32) -> i1 - - // Custom form of vector equality comparison. - %x = cmpi "eq", %lhs, %rhs : vector<4xi64> - - // Generic form of the same operation. - %x = "std.cmpi"(%lhs, %rhs) {predicate = 0 : i64} - : (vector<4xi64>, vector<4xi64>) -> vector<4xi1> - ``` - }]; - - let arguments = (ins - CmpIPredicateAttr:$predicate, - SignlessIntegerLike:$lhs, - SignlessIntegerLike:$rhs - ); - let results = (outs BoolLike:$result); - - let builders = [ - OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - ::buildCmpIOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); - - CmpIPredicate getPredicate() { - return (CmpIPredicate)(*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let hasFolder = 1; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// @@ -1095,264 +704,111 @@ } //===----------------------------------------------------------------------===// -// CopySignOp +// MaxFOp //===----------------------------------------------------------------------===// 
-def CopySignOp : FloatBinaryOp<"copysign"> { - let summary = "A copysign operation"; +def MaxFOp : FloatBinaryOp<"maxf"> { + let summary = "floating-point maximum operation"; let description = [{ Syntax: ``` - operation ::= ssa-id `=` `std.copysign` ssa-use `,` ssa-use `:` type + operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type ``` - The `copysign` returns a value with the magnitude of the first operand and - the sign of the second operand. It takes two operands and returns one - result of the same type. This type may be a float scalar type, a vector - whose element type is float, or a tensor of floats. It has no standard - attributes. + Returns the maximum of the two arguments, treating -0.0 as less than +0.0. + If one of the arguments is NaN, then the result is also NaN. Example: ```mlir - // Scalar copysign value. - %a = copysign %b, %c : f64 - - // SIMD vector element-wise copysign value. - %f = copysign %g, %h : vector<4xf32> - - // Tensor element-wise copysign value. - %x = copysign %y, %z : tensor<4x?xf8> + // Scalar floating-point maximum. 
+ %a = maxf %b, %c : f64 ``` }]; } //===----------------------------------------------------------------------===// -// DivFOp -//===----------------------------------------------------------------------===// - -def DivFOp : FloatBinaryOp<"divf"> { - let summary = "floating point division operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// FmaFOp +// MaxSIOp //===----------------------------------------------------------------------===// -def FmaFOp : FloatTernaryOp<"fmaf"> { - let summary = "floating point fused multipy-add operation"; +def MaxSIOp : IntBinaryOp<"maxsi"> { + let summary = "signed integer maximum operation"; let description = [{ Syntax: ``` - operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type + operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type ``` - The `fmaf` operation takes three operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. + Returns the larger of the two operands, comparing the values as signed integers. Example: ```mlir - // Scalar fused multiply-add: d = a*b + c - %d = fmaf %a, %b, %c : f64 - - // SIMD vector fused multiply-add, e.g. for Intel SSE. - %i = fmaf %f, %g, %h : vector<4xf32> - - // Tensor fused multiply-add. - %w = fmaf %x, %y, %z : tensor<4x?xbf16> + // Scalar signed integer maximum. + %a = maxsi %b, %c : i64 ``` - - The semantics of the operation correspond to those of the `llvm.fma` - [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the - particular case of lowering to LLVM, this is guaranteed to lower - to the `llvm.fma.*` intrinsic.
}]; } //===----------------------------------------------------------------------===// -// FPExtOp +// MaxUIOp //===----------------------------------------------------------------------===// -def FPExtOp : ArithmeticCastOp<"fpext"> { - let summary = "cast from floating-point to wider floating-point"; +def MaxUIOp : IntBinaryOp<"maxui"> { + let summary = "unsigned integer maximum operation"; let description = [{ - Cast a floating-point value to a larger floating-point-typed value. - The destination type must to be strictly wider than the source type. - When operating on vectors, casts elementwise. - }]; -} + Syntax: -//===----------------------------------------------------------------------===// -// FPToSIOp -//===----------------------------------------------------------------------===// + ``` + operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type + ``` -def FPToSIOp : ArithmeticCastOp<"fptosi"> { - let summary = "cast from floating-point type to integer type"; - let description = [{ - Cast from a value interpreted as floating-point to the nearest (rounding - towards zero) signed integer value. When operating on vectors, casts - elementwise. - }]; -} + Returns the larger of the two operands, comparing the values as unsigned integers. -//===----------------------------------------------------------------------===// -// FPToUIOp -//===----------------------------------------------------------------------===// + Example: -def FPToUIOp : ArithmeticCastOp<"fptoui"> { - let summary = "cast from floating-point type to integer type"; - let description = [{ - Cast from a value interpreted as floating-point to the nearest (rounding - towards zero) unsigned integer value. When operating on vectors, casts - elementwise. + ```mlir + // Scalar unsigned integer maximum.
+ %a = maxui %b, %c : i64 + ``` }]; } //===----------------------------------------------------------------------===// -// FPTruncOp +// MinFOp //===----------------------------------------------------------------------===// -def FPTruncOp : ArithmeticCastOp<"fptrunc"> { - let summary = "cast from floating-point to narrower floating-point"; +def MinFOp : FloatBinaryOp<"minf"> { + let summary = "floating-point minimum operation"; let description = [{ - Truncate a floating-point value to a smaller floating-point-typed value. - The destination type must be strictly narrower than the source type. - If the value cannot be exactly represented, it is rounded using the default - rounding mode. When operating on vectors, casts elementwise. - }]; + Syntax: - let hasFolder = 1; + ``` + operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type + ``` + + Returns the minimum of the two arguments, treating -0.0 as less than +0.0. + If one of the arguments is NaN, then the result is also NaN. + + Example: + + ```mlir + // Scalar floating-point minimum. + %a = minf %b, %c : f64 + ``` + }]; } //===----------------------------------------------------------------------===// -// IndexCastOp +// MinSIOp //===----------------------------------------------------------------------===// -def IndexCastOp : ArithmeticCastOp<"index_cast"> { - let summary = "cast between index and integer types"; +def MinSIOp : IntBinaryOp<"minsi"> { + let summary = "signed integer minimum operation"; let description = [{ - Casts between scalar or vector integers and corresponding 'index' scalar or - vectors. Index is an integer of platform-specific bit width. If casting to - a wider integer, the value is sign-extended. If casting to a narrower - integer, the value is truncated. 
- }]; - - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// MaxFOp -//===----------------------------------------------------------------------===// - -def MaxFOp : FloatBinaryOp<"maxf"> { - let summary = "floating-point maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type - ``` - - Returns the maximum of the two arguments, treating -0.0 as less than +0.0. - If one of the arguments is NaN, then the result is also NaN. - - Example: - - ```mlir - // Scalar floating-point maximum. - %a = maxf %b, %c : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MaxSIOp -//===----------------------------------------------------------------------===// - -def MaxSIOp : IntBinaryOp<"maxsi"> { - let summary = "signed integer maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type - ``` - - Returns the larger of %a and %b comparing the values as signed integers. - - Example: - - ```mlir - // Scalar signed integer maximum. - %a = maxsi %b, %c : i64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MaxUIOp -//===----------------------------------------------------------------------===// - -def MaxUIOp : IntBinaryOp<"maxui"> { - let summary = "unsigned integer maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type - ``` - - Returns the larger of %a and %b comparing the values as unsigned integers. - - Example: - - ```mlir - // Scalar unsigned integer maximum. 
- %a = maxui %b, %c : i64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MinFOp -//===----------------------------------------------------------------------===// - -def MinFOp : FloatBinaryOp<"minf"> { - let summary = "floating-point minimum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type - ``` - - Returns the minimum of the two arguments, treating -0.0 as less than +0.0. - If one of the arguments is NaN, then the result is also NaN. - - Example: - - ```mlir - // Scalar floating-point minimum. - %a = minf %b, %c : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MinSIOp -//===----------------------------------------------------------------------===// - -def MinSIOp : IntBinaryOp<"minsi"> { - let summary = "signed integer minimum operation"; - let description = [{ - Syntax: + Syntax: ``` operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type @@ -1393,119 +849,6 @@ }]; } -//===----------------------------------------------------------------------===// -// MulFOp -//===----------------------------------------------------------------------===// - -def MulFOp : FloatBinaryOp<"mulf"> { - let summary = "floating point multiplication operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.mulf` ssa-use `,` ssa-use `:` type - ``` - - The `mulf` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. - - Example: - - ```mlir - // Scalar multiplication. - %a = mulf %b, %c : f64 - - // SIMD pointwise vector multiplication, e.g. for Intel SSE. - %f = mulf %g, %h : vector<4xf32> - - // Tensor pointwise multiplication. 
- %x = mulf %y, %z : tensor<4x?xbf16> - ``` - - TODO: In the distant future, this will accept optional attributes for fast - math, contraction, rounding mode, and other controls. - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// MulIOp -//===----------------------------------------------------------------------===// - -def MulIOp : IntBinaryOp<"muli", [Commutative]> { - let summary = "integer multiplication operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// NegFOp -//===----------------------------------------------------------------------===// - -def NegFOp : FloatUnaryOp<"negf"> { - let summary = "floating point negation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `negf` ssa-use `:` type - ``` - - The `negf` operation computes the negation of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar negation value. - %a = negf %b : f64 - - // SIMD vector element-wise negation value. - %f = negf %g : vector<4xf32> - - // Tensor element-wise negation value. - %x = negf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// OrOp -//===----------------------------------------------------------------------===// - -def OrOp : IntBinaryOp<"or", [Commutative]> { - let summary = "integer binary or"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `or` ssa-use `,` ssa-use `:` type - ``` - - The `or` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. 
It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise or. - %a = or %b, %c : i64 - - // SIMD vector element-wise bitwise integer or. - %f = or %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer or. - %x = or %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// @@ -1538,14 +881,6 @@ let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)"; } -//===----------------------------------------------------------------------===// -// RemFOp -//===----------------------------------------------------------------------===// - -def RemFOp : FloatBinaryOp<"remf"> { - let summary = "floating point division remainder operation"; -} - //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// @@ -1641,236 +976,6 @@ let hasFolder = 1; } -//===----------------------------------------------------------------------===// -// ShiftLeftOp -//===----------------------------------------------------------------------===// - -def ShiftLeftOp : IntBinaryOp<"shift_left"> { - let summary = "integer left-shift"; - let description = [{ - The shift_left operation shifts an integer value to the left by a variable - amount. The low order bits are filled with zeros. 
- - Example: - - ```mlir - %1 = constant 5 : i8 // %1 is 0b00000101 - %2 = constant 3 : i8 - %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// SignedDivIOp -//===----------------------------------------------------------------------===// - -def SignedDivIOp : IntBinaryOp<"divi_signed"> { - let summary = "signed integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `divi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards zero. Treats the leading bit as - sign, i.e. `6 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. - %a = divi_signed %b, %c : i64 - - // SIMD vector element-wise division. - %f = divi_signed %g, %h : vector<4xi32> - - // Tensor element-wise integer division. - %x = divi_signed %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedFloorDivIOp -//===----------------------------------------------------------------------===// - -def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> { - let summary = "signed floor integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `floordivi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. 
- %a = floordivi_signed %b, %c : i64 - - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedCeilDivIOp -//===----------------------------------------------------------------------===// - -def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> { - let summary = "signed ceil integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `ceildivi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. - %a = ceildivi_signed %b, %c : i64 - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedRemIOp -//===----------------------------------------------------------------------===// - -def SignedRemIOp : IntBinaryOp<"remi_signed"> { - let summary = "signed integer division remainder operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.remi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division remainder. Treats the leading bit as sign, i.e. `6 % - -2 = 0`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar signed integer division remainder. - %a = remi_signed %b, %c : i64 - - // SIMD vector element-wise division remainder. - %f = remi_signed %g, %h : vector<4xi32> - - // Tensor element-wise integer division remainder. 
- %x = remi_signed %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedShiftRightOp -//===----------------------------------------------------------------------===// - -def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> { - let summary = "signed integer right-shift"; - let description = [{ - The shift_right_signed operation shifts an integer value to the right by - a variable amount. The integer is interpreted as signed. The high order - bits in the output are filled with copies of the most-significant bit - of the shifted value (which means that the sign of the value is preserved). - - Example: - - ```mlir - %1 = constant 160 : i8 // %1 is 0b10100000 - %2 = constant 3 : i8 - %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 - %4 = constant 96 : i8 // %4 is 0b01100000 - %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// SignExtendIOp -//===----------------------------------------------------------------------===// - -def SignExtendIOp : Std_Op<"sexti", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer sign extension operation"; - let description = [{ - The integer sign extension operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be larger than the input bit-width (N > M). - The top-most (N - M) bits of the output are filled with copies - of the most-significant bit of the input. 
- - Example: - - ```mlir - %1 = constant 5 : i3 // %1 is 0b101 - %2 = sexti %1 : i3 to i6 // %2 is 0b111101 - %3 = constant 2 : i3 // %3 is 0b010 - %4 = sexti %3 : i3 to i6 // %4 is 0b000010 - - %5 = sexti %0 : vector<2 x i32> to vector<2 x i64> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SIToFPOp -//===----------------------------------------------------------------------===// - -def SIToFPOp : ArithmeticCastOp<"sitofp"> { - let summary = "cast from integer type to floating-point"; - let description = [{ - Cast from a value interpreted as a signed integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. When operating on vectors, casts - elementwise. 
- }]; -} - //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// @@ -1918,25 +1023,6 @@ let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } -//===----------------------------------------------------------------------===// -// SubFOp -//===----------------------------------------------------------------------===// - -def SubFOp : FloatBinaryOp<"subf"> { - let summary = "floating point subtraction operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SubIOp -//===----------------------------------------------------------------------===// - -def SubIOp : IntBinaryOp<"subi"> { - let summary = "integer subtraction operation"; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// @@ -2025,225 +1111,4 @@ let hasCanonicalizer = 1; } -//===----------------------------------------------------------------------===// -// TruncateIOp -//===----------------------------------------------------------------------===// - -def TruncateIOp : Std_Op<"trunci", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer truncation operation"; - let description = [{ - The integer truncation operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be smaller than the input bit-width (N < M). - The top-most (N - M) bits of the input are discarded. 
- - Example: - - ```mlir - %1 = constant 21 : i5 // %1 is 0b10101 - %2 = trunci %1 : i5 to i4 // %2 is 0b0101 - %3 = trunci %1 : i5 to i3 // %3 is 0b101 - - %5 = trunci %0 : vector<2 x i32> to vector<2 x i16> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; - - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UIToFPOp -//===----------------------------------------------------------------------===// - -def UIToFPOp : ArithmeticCastOp<"uitofp"> { - let summary = "cast from unsigned integer type to floating-point"; - let description = [{ - Cast from a value interpreted as unsigned integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. When operating on vectors, casts - elementwise. - }]; -} - -//===----------------------------------------------------------------------===// -// UnsignedDivIOp -//===----------------------------------------------------------------------===// - -def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> { - let summary = "unsigned integer division operation"; - let description = [{ - Syntax: - ``` - operation ::= ssa-id `=` `std.divi_unsigned` ssa-use `,` ssa-use `:` type - ``` - - Unsigned integer division. Rounds towards zero. Treats the leading bit as - the most significant, i.e. for `i16` given two's complement representation, - `6 / -2 = 6 / (2^16 - 2) = 0`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar unsigned integer division. 
- %a = divi_unsigned %b, %c : i64 - - // SIMD vector element-wise division. - %f = divi_unsigned %g, %h : vector<4xi32> - - // Tensor element-wise integer division. - %x = divi_unsigned %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UnsignedRemIOp -//===----------------------------------------------------------------------===// - -def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> { - let summary = "unsigned integer division remainder operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.remi_unsigned` ssa-use `,` ssa-use `:` type - ``` - - Unsigned integer division remainder. Treats the leading bit as the most - significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar unsigned integer division remainder. - %a = remi_unsigned %b, %c : i64 - - // SIMD vector element-wise division remainder. - %f = remi_unsigned %g, %h : vector<4xi32> - - // Tensor element-wise integer division remainder. - %x = remi_unsigned %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UnsignedShiftRightOp -//===----------------------------------------------------------------------===// - -def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> { - let summary = "unsigned integer right-shift"; - let description = [{ - The shift_right_unsigned operation shifts an integer value to the right by - a variable amount. The integer is interpreted as unsigned. The high order - bits are always filled with zeros. 
- - Example: - - ```mlir - %1 = constant 160 : i8 // %1 is 0b10100000 - %2 = constant 3 : i8 - %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// XOrOp -//===----------------------------------------------------------------------===// - -def XOrOp : IntBinaryOp<"xor", [Commutative]> { - let summary = "integer binary xor"; - let description = [{ - The `xor` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise xor. - %a = xor %b, %c : i64 - - // SIMD vector element-wise bitwise integer xor. - %f = xor %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer xor. - %x = xor %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// ZeroExtendIOp -//===----------------------------------------------------------------------===// - -def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer zero extension operation"; - let description = [{ - The integer zero extension operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be larger than the input bit-width (N > M). - The top-most (N - M) bits of the output are filled with zeros. 
- - Example: - - ```mlir - %1 = constant 5 : i3 // %1 is 0b101 - %2 = zexti %1 : i3 to i6 // %2 is 0b000101 - %3 = constant 2 : i3 // %3 is 0b010 - %4 = zexti %3 : i3 to i6 // %4 is 0b000010 - - %5 = zexti %0 : vector<2 x i32> to vector<2 x i64> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; -} - #endif // STANDARD_OPS diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td @@ -36,50 +36,4 @@ let cppNamespace = "::mlir"; } -// The predicate indicates the type of the comparison to perform: -// (un)orderedness, (in)equality and less/greater than (or equal to) as -// well as predicates that are always true or false. 
-def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">; -def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">; -def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">; -def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">; -def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">; -def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">; -def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">; -def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">; -def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">; -def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">; -def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">; -def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">; -def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">; -def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">; -def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">; -def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">; - -def CmpFPredicateAttr : I64EnumAttr< - "CmpFPredicate", "", - [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE, - CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT, - CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> { - let cppNamespace = "::mlir"; -} - -def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; -def CMPI_P_NE : I64EnumAttrCase<"ne", 1>; -def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>; -def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>; -def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>; -def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>; -def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>; -def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>; -def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>; -def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>; - -def CmpIPredicateAttr : I64EnumAttr< - "CmpIPredicate", "", - [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT, - CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> { - let cppNamespace = "::mlir"; -} - #endif // STANDARD_OPS_BASE diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- 
a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -43,7 +43,7 @@ /// Creates an instance of the StdExpand pass that legalizes Std /// dialect ops to be convertible to LLVM. For example, -/// `std.ceildivi_signed` gets transformed to a number of std operations, +/// `arith.ceildivsi` gets transformed to a number of simpler operations, /// which can be lowered to LLVM; `memref.reshape` gets converted to /// `memref_reinterpret_cast`. std::unique_ptr createStdExpandOpsPass(); diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h --- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h @@ -16,6 +16,7 @@ #ifndef MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H #define MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -24,7 +25,7 @@ namespace mlir { /// Matches a ConstantIndexOp. -detail::op_matcher matchConstantIndex(); +detail::op_matcher matchConstantIndex(); /// Detects the `values` produced by a ConstantIndexOp and places the new /// constant in place of the corresponding sentinel value.
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_ #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_ +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td @@ -46,6 +46,7 @@ }]; let hasConstantMaterializer = 1; + let dependentDialects = ["arith::ArithmeticDialect"]; } #endif // TENSOR_BASE diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -21,6 +21,7 @@ let name = "vector"; let cppNamespace = "::mlir::vector"; let hasConstantMaterializer = 1; + let dependentDialects = ["arith::ArithmeticDialect"]; } // Base class for Vector dialect ops. @@ -337,7 +338,7 @@ static SmallVector inferDestShape( ArrayRef shape, ArrayRef reducedDimsMask) { - assert(shape.size() == reducedDimsMask.size() && + assert(shape.size() == reducedDimsMask.size() && "shape and maks of different sizes"); SmallVector res; for (auto it : llvm::zip(reducedDimsMask, shape)) @@ -555,7 +556,7 @@ %idx0 = ... : index // dynamic computation producing the value 1 of index type %idx1 = ... : index - %0 = constant dense<0, 1, 2, 3>: vector<4xi32> + %0 = arith.constant dense<0, 1, 2, 3>: vector<4xi32> // extracts values [0, 1] %1 = vector.extract_map %0[%idx0] : vector<4xi32> to vector<2xi32> // extracts values [1, 2] @@ -743,7 +744,7 @@ %idx0 = ... : index // dynamic computation producing the value 1 of index type %idx1 = ... 
: index / - %0 = constant dense<0, 1, 2, 3>: vector<4xi32> + %0 = arith.constant dense<0, 1, 2, 3>: vector<4xi32> // extracts values [0, 1] %1 = vector.extract_map %0[%idx0] : vector<4xi32> to vector<2xi32> // extracts values [1, 2] diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -173,9 +173,9 @@ /// canonicalizations pattern to propagate and fold the vector /// insert_map/extract_map operations. /// Transforms: -// %v = addf %a, %b : vector<32xf32> +// %v = arith.addf %a, %b : vector<32xf32> /// to: -/// %v = addf %a, %b : vector<32xf32> +/// %v = arith.addf %a, %b : vector<32xf32> /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32> /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32> Optional diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -325,7 +325,7 @@ %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32> %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32> - %d = addf %1, %2 : f32 + %d = arith.addf %1, %2 : f32 ``` }]; let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a, diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1383,11 +1383,11 @@ /// /// Examples: /// ``` -/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32 +/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32 /// ``` /// can be tensorized to /// ``` -/// %tensor = "std.addf"(%a, %b) : (tensor, tensor) +/// %tensor = "arith.addf"(%a, %b) : (tensor, tensor) /// -> tensor /// ``` /// diff --git 
a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -52,6 +53,7 @@ // clang-format off registry.insert - %3 = addf %2, %2 : f32 + %3 = arith.addf %2, %2 : f32 affine.store %3, %arg0[%arg2] : memref<10xf32> } affine.for %arg2 = 0 to 10 { %2 = affine.load %1[%arg2] : memref<10xf32> - %3 = mulf %2, %2 : f32 + %3 = arith.mulf %2, %2 : f32 affine.store %3, %arg1[%arg2] : memref<10xf32> } return @@ -67,10 +67,10 @@ affine.store %cst, %0[0] : memref<1xf32> affine.store %cst, %1[0] : memref<1xf32> %2 = affine.load %1[0] : memref<1xf32> - %3 = mulf %2, %2 : f32 + %3 = arith.mulf %2, %2 : f32 affine.store %3, %arg1[%arg2] : memref<10xf32> %4 = affine.load %0[0] : memref<1xf32> - %5 = addf %4, %4 : f32 + %5 = arith.addf %4, %4 : f32 affine.store %5, %arg0[%arg2] : memref<10xf32> } return @@ -87,7 +87,7 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> - %2 = mulf %0, %1 : f32 + %2 = arith.mulf %0, %1 : f32 affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> } } @@ -95,7 +95,7 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> - %2 = addf %0, %1 : f32 + %2 = arith.addf %0, %1 : f32 affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32> } } @@ -111,11 +111,11 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> - %2 = mulf %0, %1 : f32 + %2 = arith.mulf %0, %1 : f32 affine.store %2, %arg3[%arg5, 
%arg6] : memref<10x10xf32> %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> - %5 = addf %3, %4 : f32 + %5 = arith.addf %3, %4 : f32 affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32> } } @@ -481,6 +481,7 @@ let summary = "Coalesce nested loops with independent bounds into a single " "loop"; let constructor = "mlir::createLoopCoalescingPass()"; + let dependentDialects = ["arith::ArithmeticDialect"]; } def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> { @@ -524,7 +525,7 @@ %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) { affine.for %arg3 = 0 to 16 { %a = affine.load %A[%arg3] : memref<16xf64, #tile> - %p = mulf %a, %a : f64 + %p = arith.mulf %a, %a : f64 affine.store %p, %A[%arg3] : memref<16xf64, #tile> } %c = alloc() : memref<16xf64, #tile> @@ -540,7 +541,7 @@ -> memref<4x4xf64> { affine.for %arg3 = 0 to 16 { %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> - %4 = mulf %3, %3 : f64 + %4 = arith.mulf %3, %3 : f64 affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> } %0 = alloc() : memref<4x4xf64> @@ -566,8 +567,8 @@ %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8> %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8> %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> - %3 = muli %0, %1 : i32 - %4 = addi %2, %3 : i32 + %3 = arith.muli %0, %1 : i32 + %4 = arith.addi %2, %3 : i32 affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> } } @@ -590,8 +591,8 @@ %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32> %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32> %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32> - %3 = muli %0, %1 : i32 - %4 = addi %2, %3 : i32 + %3 = arith.muli %0, %1 : i32 + %4 = arith.addi %2, %3 : i32 affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32> } } diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp 
b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BuiltinOps.h" @@ -53,10 +54,10 @@ Operation *combinerOp = combinerOps.back(); Optional maybeKind = TypeSwitch>(combinerOp) - .Case([](Operation *) { return AtomicRMWKind::addf; }) - .Case([](Operation *) { return AtomicRMWKind::mulf; }) - .Case([](Operation *) { return AtomicRMWKind::addi; }) - .Case([](Operation *) { return AtomicRMWKind::muli; }) + .Case([](arith::AddFOp) { return AtomicRMWKind::addf; }) + .Case([](arith::MulFOp) { return AtomicRMWKind::mulf; }) + .Case([](arith::AddIOp) { return AtomicRMWKind::addi; }) + .Case([](arith::MulIOp) { return AtomicRMWKind::muli; }) .Default([](Operation *) -> Optional { // TODO: AtomicRMW supports other kinds of reductions this is // currently not detecting, add those when the need arises. @@ -640,10 +641,9 @@ auto symbol = operands[i]; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. 
- if (auto cOp = symbol.getDefiningOp()) + if (auto cOp = symbol.getDefiningOp()) dependenceDomain->addBound(FlatAffineConstraints::EQ, - valuePosMap.getSymPos(symbol), - cOp.getValue()); + valuePosMap.getSymPos(symbol), cOp.value()); } }; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IntegerSet.h" @@ -654,8 +655,8 @@ // Add top level symbol. appendSymbolId(val); // Check if the symbol is a constant. - if (auto constOp = val.getDefiningOp()) - addBound(BoundType::EQ, val, constOp.getValue()); + if (auto constOp = val.getDefiningOp()) + addBound(BoundType::EQ, val, constOp.value()); } LogicalResult diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -37,13 +37,10 @@ mlir-headers LINK_LIBS PUBLIC - MLIRAffine MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRInferTypeOpInterface - MLIRLinalg - MLIRSCF ) add_mlir_library(MLIRLoopAnalysis diff --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp --- a/mlir/lib/Analysis/NumberOfExecutions.cpp +++ b/mlir/lib/Analysis/NumberOfExecutions.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/NumberOfExecutions.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp 
b/mlir/lib/Analysis/SliceAnalysis.cpp --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -11,9 +11,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/PresburgerSet.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/IntegerSet.h" #include "llvm/ADT/SmallPtrSet.h" @@ -98,8 +99,8 @@ assert(cst->containsId(value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. - if (auto cOp = value.getDefiningOp()) - cst->addBound(FlatAffineConstraints::EQ, value, cOp.getValue()); + if (auto cOp = value.getDefiningOp()) + cst->addBound(FlatAffineConstraints::EQ, value, cOp.value()); } else if (auto loop = getForInductionVarOwner(value)) { if (failed(cst->addAffineForOpDomain(loop))) return failure(); @@ -517,8 +518,8 @@ assert(isValidSymbol(symbol)); // Check if the symbol is a constant. 
if (auto *op = symbol.getDefiningOp()) { - if (auto constOp = dyn_cast(op)) { - cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.getValue()); + if (auto constOp = dyn_cast(op)) { + cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.value()); } } } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -15,6 +15,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -56,11 +57,11 @@ } Value visitAddExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr(expr); + return buildBinaryExpr(expr); } Value visitMulExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr(expr); + return buildBinaryExpr(expr); } /// Euclidean modulo operation: negative RHS is not allowed. 
@@ -89,11 +90,12 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value remainder = builder.create(loc, lhs, rhs); - Value zeroCst = builder.create(loc, 0); - Value isRemainderNegative = - builder.create(loc, CmpIPredicate::slt, remainder, zeroCst); - Value correctedRemainder = builder.create(loc, remainder, rhs); + Value remainder = builder.create(loc, lhs, rhs); + Value zeroCst = builder.create(loc, 0); + Value isRemainderNegative = builder.create( + loc, arith::CmpIPredicate::slt, remainder, zeroCst); + Value correctedRemainder = + builder.create(loc, remainder, rhs); Value result = builder.create(loc, isRemainderNegative, correctedRemainder, remainder); return result; @@ -126,15 +128,16 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value noneCst = builder.create(loc, -1); - Value negative = - builder.create(loc, CmpIPredicate::slt, lhs, zeroCst); - Value negatedDecremented = builder.create(loc, noneCst, lhs); + Value zeroCst = builder.create(loc, 0); + Value noneCst = builder.create(loc, -1); + Value negative = builder.create( + loc, arith::CmpIPredicate::slt, lhs, zeroCst); + Value negatedDecremented = builder.create(loc, noneCst, lhs); Value dividend = builder.create(loc, negative, negatedDecremented, lhs); - Value quotient = builder.create(loc, dividend, rhs); - Value correctedQuotient = builder.create(loc, noneCst, quotient); + Value quotient = builder.create(loc, dividend, rhs); + Value correctedQuotient = + builder.create(loc, noneCst, quotient); Value result = builder.create(loc, negative, correctedQuotient, quotient); return result; @@ -165,27 +168,26 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value oneCst = builder.create(loc, 1); - Value nonPositive = - builder.create(loc, CmpIPredicate::sle, lhs, 
zeroCst); - Value negated = builder.create(loc, zeroCst, lhs); - Value decremented = builder.create(loc, lhs, oneCst); + Value zeroCst = builder.create(loc, 0); + Value oneCst = builder.create(loc, 1); + Value nonPositive = builder.create( + loc, arith::CmpIPredicate::sle, lhs, zeroCst); + Value negated = builder.create(loc, zeroCst, lhs); + Value decremented = builder.create(loc, lhs, oneCst); Value dividend = builder.create(loc, nonPositive, negated, decremented); - Value quotient = builder.create(loc, dividend, rhs); - Value negatedQuotient = builder.create(loc, zeroCst, quotient); - Value incrementedQuotient = builder.create(loc, quotient, oneCst); + Value quotient = builder.create(loc, dividend, rhs); + Value negatedQuotient = + builder.create(loc, zeroCst, quotient); + Value incrementedQuotient = + builder.create(loc, quotient, oneCst); Value result = builder.create(loc, nonPositive, negatedQuotient, incrementedQuotient); return result; } Value visitConstantExpr(AffineConstantExpr expr) { - auto valueAttr = - builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); - auto op = - builder.create(loc, builder.getIndexType(), valueAttr); + auto op = builder.create(loc, expr.getValue()); return op.getResult(); } @@ -242,20 +244,21 @@ /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the /// `cmpi` operation followed by the `select` operation: /// -/// %cond = cmpi "predicate" %v0, %v1 +/// %cond = arith.cmpi "predicate" %v0, %v1 /// %result = select %cond, %v0, %v1 /// /// Multiple values are scanned in a linear sequence. This creates a data /// dependences that wouldn't exist in a tree reduction, but is easier to /// recognize as a reduction by the subsequent passes. 
-static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, +static Value buildMinMaxReductionSeq(Location loc, + arith::CmpIPredicate predicate, ValueRange values, OpBuilder &builder) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { - auto cmpOp = builder.create(loc, predicate, value, *valueIt); + auto cmpOp = builder.create(loc, predicate, value, *valueIt); value = builder.create(loc, cmpOp.getResult(), value, *valueIt); } @@ -267,7 +270,8 @@ static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands) { if (auto values = expandAffineMap(builder, loc, map, operands)) - return buildMinMaxReductionSeq(loc, CmpIPredicate::sgt, *values, builder); + return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values, + builder); return nullptr; } @@ -276,7 +280,8 @@ static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands) { if (auto values = expandAffineMap(builder, loc, map, operands)) - return buildMinMaxReductionSeq(loc, CmpIPredicate::slt, *values, builder); + return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values, + builder); return nullptr; } @@ -356,7 +361,7 @@ Location loc = op.getLoc(); Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); - Value step = rewriter.create(loc, op.getStep()); + Value step = rewriter.create(loc, op.getStep()); auto scfForOp = rewriter.create(loc, lowerBound, upperBound, step, op.getIterOperands()); rewriter.eraseBlock(scfForOp.getBody()); @@ -399,7 +404,7 @@ } steps.reserve(op.steps().size()); for (Attribute step : op.steps()) - steps.push_back(rewriter.create( + steps.push_back(rewriter.create( loc, step.cast().getInt())); // Get the terminator op. @@ -475,7 +480,7 @@ // Now we just have to handle the condition logic. 
auto integerSet = op.getIntegerSet(); - Value zeroConstant = rewriter.create(loc, 0); + Value zeroConstant = rewriter.create(loc, 0); SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); @@ -492,14 +497,17 @@ operandsRef.drop_front(numDims)); if (!affResult) return failure(); - auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; + auto pred = + isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge; Value cmpVal = - rewriter.create(loc, pred, affResult, zeroConstant); - cond = - cond ? rewriter.create(loc, cond, cmpVal).getResult() : cmpVal; + rewriter.create(loc, pred, affResult, zeroConstant); + cond = cond + ? rewriter.create(loc, cond, cmpVal).getResult() + : cmpVal; } cond = cond ? cond - : rewriter.create(loc, /*value=*/1, /*width=*/1); + : rewriter.create(loc, /*value=*/1, + /*width=*/1); bool hasElseRegion = !op.elseRegion().empty(); auto ifOp = rewriter.create(loc, op.getResultTypes(), cond, @@ -750,8 +758,9 @@ populateAffineToStdConversionPatterns(patterns); populateAffineToVectorConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); + target + .addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRAffine + MLIRArithmetic MLIRMemRef MLIRSCF MLIRPass diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp @@ -0,0 +1,304 @@ +//===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// +// +// Part of the LLVM Project, under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "../PassDetail.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// Straightforward Op Lowerings +//===----------------------------------------------------------------------===// + +using AddIOpLowering = VectorConvertToLLVMPattern; +using SubIOpLowering = VectorConvertToLLVMPattern; +using MulIOpLowering = VectorConvertToLLVMPattern; +using DivUIOpLowering = + VectorConvertToLLVMPattern; +using DivSIOpLowering = + VectorConvertToLLVMPattern; +using RemUIOpLowering = + VectorConvertToLLVMPattern; +using RemSIOpLowering = + VectorConvertToLLVMPattern; +using AndIOpLowering = VectorConvertToLLVMPattern; +using OrIOpLowering = VectorConvertToLLVMPattern; +using XOrIOpLowering = VectorConvertToLLVMPattern; +using ShLIOpLowering = VectorConvertToLLVMPattern; +using ShRUIOpLowering = + VectorConvertToLLVMPattern; +using ShRSIOpLowering = + VectorConvertToLLVMPattern; +using NegFOpLowering = VectorConvertToLLVMPattern; +using AddFOpLowering = VectorConvertToLLVMPattern; +using SubFOpLowering = VectorConvertToLLVMPattern; +using MulFOpLowering = VectorConvertToLLVMPattern; +using DivFOpLowering = VectorConvertToLLVMPattern; +using RemFOpLowering = VectorConvertToLLVMPattern; +using ExtUIOpLowering = + VectorConvertToLLVMPattern; +using ExtSIOpLowering = + VectorConvertToLLVMPattern; +using ExtFOpLowering = VectorConvertToLLVMPattern; 
+using TruncIOpLowering =
+    VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
+using TruncFOpLowering =
+    VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
+using UIToFPOpLowering =
+    VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
+using SIToFPOpLowering =
+    VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
+using FPToUIOpLowering =
+    VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
+using FPToSIOpLowering =
+    VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
+using BitcastOpLowering =
+    VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
+
+//===----------------------------------------------------------------------===//
+// Op Lowering Patterns
+//===----------------------------------------------------------------------===//
+
+/// Directly lower to LLVM op.
+struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
+  using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// The lowering of index_cast becomes an integer conversion since index
+/// becomes an integer. If the bit width of the source and target integer
+/// types is the same, just erase the cast. If the target type is wider,
+/// sign-extend the value, otherwise truncate it.
+struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
+  using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Lowers arith.cmpi to llvm.icmp; multi-dimensional vectors are unrolled
+/// into 1-D llvm.icmp ops.
+struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
+  using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Lowers arith.cmpf to llvm.fcmp; multi-dimensional vectors are unrolled
+/// into 1-D llvm.fcmp ops.
+struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
+  using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ConstantOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
+                                       adaptor.getOperands(),
+                                       *getTypeConverter(), rewriter);
+}
+
+//===----------------------------------------------------------------------===//
+// IndexCastOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult IndexCastOpLowering::matchAndRewrite(
+    arith::IndexCastOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto targetType = typeConverter->convertType(op.getResult().getType());
+  auto targetElementType =
+      typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
+          .cast<IntegerType>();
+  auto sourceElementType =
+      getElementTypeOrSelf(adaptor.in()).cast<IntegerType>();
+  unsigned targetBits = targetElementType.getWidth();
+  unsigned sourceBits = sourceElementType.getWidth();
+
+  if (targetBits == sourceBits)
+    rewriter.replaceOp(op, adaptor.in());
+  else if (targetBits < sourceBits)
+    rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.in());
+  else
+    rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.in());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOpLowering
+//===----------------------------------------------------------------------===//
+
+// Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums
+// share numerical values so just cast.
+template <typename LLVMPredType, typename PredType>
+static LLVMPredType convertCmpPredicate(PredType pred) {
+  return static_cast<LLVMPredType>(pred);
+}
+
+LogicalResult
+CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  auto operandType = adaptor.lhs().getType();
+  auto resultType = op.getResult().getType();
+
+  // Handle the scalar and 1D vector cases.
+  if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
+        op, typeConverter->convertType(resultType),
+        convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
+        adaptor.lhs(), adaptor.rhs());
+    return success();
+  }
+
+  auto vectorType = resultType.dyn_cast<VectorType>();
+  if (!vectorType)
+    return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+  // Multi-dimensional vectors: emit one 1-D llvm.icmp per innermost vector.
+  // The helper's result is the pattern's result; nothing follows it.
+  return LLVM::detail::handleMultidimensionalVectors(
+      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+      [&](Type llvm1DVectorTy, ValueRange operands) {
+        OpAdaptor adaptor(operands);
+        return rewriter.create<LLVM::ICmpOp>(
+            op.getLoc(), llvm1DVectorTy,
+            convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
+            adaptor.lhs(), adaptor.rhs());
+      },
+      rewriter);
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  auto operandType = adaptor.lhs().getType();
+  auto resultType = op.getResult().getType();
+
+  // Handle the scalar and 1D vector cases.
+  if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
+        op, typeConverter->convertType(resultType),
+        convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
+        adaptor.lhs(), adaptor.rhs());
+    return success();
+  }
+
+  auto vectorType = resultType.dyn_cast<VectorType>();
+  if (!vectorType)
+    return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+  // Multi-dimensional vectors: emit one 1-D llvm.fcmp per innermost vector.
+  return LLVM::detail::handleMultidimensionalVectors(
+      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+      [&](Type llvm1DVectorTy, ValueRange operands) {
+        OpAdaptor adaptor(operands);
+        return rewriter.create<LLVM::FCmpOp>(
+            op.getLoc(), llvm1DVectorTy,
+            convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
+            adaptor.lhs(), adaptor.rhs());
+      },
+      rewriter);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ConvertArithmeticToLLVMPass
+    : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
+  ConvertArithmeticToLLVMPass() = default;
+
+  void runOnFunction() override {
+    LLVMConversionTarget target(getContext());
+    RewritePatternSet patterns(&getContext());
+
+    LowerToLLVMOptions options(&getContext());
+    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+      options.overrideIndexBitwidth(indexBitwidth);
+
+    LLVMTypeConverter converter(&getContext(), options);
+    mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
+                                                            patterns);
+
+    if (failed(
+            applyPartialConversion(getFunction(), target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mlir::arith::populateArithmeticToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<
+    ConstantOpLowering,
+    AddIOpLowering,
+    SubIOpLowering,
+    MulIOpLowering,
+    DivUIOpLowering,
+    DivSIOpLowering,
+    RemUIOpLowering,
+    RemSIOpLowering,
+    AndIOpLowering,
+    OrIOpLowering,
+    XOrIOpLowering,
+    ShLIOpLowering,
+    ShRUIOpLowering,
+    ShRSIOpLowering,
+    NegFOpLowering,
+    AddFOpLowering,
+    SubFOpLowering,
+    MulFOpLowering,
+    DivFOpLowering,
+    RemFOpLowering,
+    ExtUIOpLowering,
+    ExtSIOpLowering,
+    ExtFOpLowering,
+    TruncIOpLowering,
+    TruncFOpLowering,
+    UIToFPOpLowering,
+    SIToFPOpLowering,
+    FPToUIOpLowering,
+    FPToSIOpLowering,
+    IndexCastOpLowering,
+    BitcastOpLowering,
+    CmpIOpLowering,
+    CmpFOpLowering
+  >(converter);
+  // clang-format on
+}
+
+std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
+  return std::make_unique<ConvertArithmeticToLLVMPass>();
+}
diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArithmeticToLLVM
+  ArithmeticToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithmeticToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMCommonConversion
+  MLIRLLVMIR
+  )
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -0,0 +1,826 @@
+//===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
+#include "../PassDetail.h"
+#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "arith-to-spirv-pattern"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Operation Conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Converts composite arith.constant operation to spv.Constant.
+struct ConstantCompositeOpPattern final
+    : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts scalar arith.constant operation to spv.Constant.
+struct ConstantScalarOpPattern final
+    : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.remsi to SPIR-V ops.
+///
+/// This cannot be merged into the template unary/binary pattern due to Vulkan
+/// restrictions over spv.SRem and spv.SMod.
+struct RemSIOpPattern final : public OpConversionPattern<arith::RemSIOp> {
+  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts bitwise operations to SPIR-V operations. This is a special pattern
+/// separate from `spirv::UnaryAndBinaryOpPattern` because, if the operands are
+/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
+/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
+template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
+struct BitwiseOpPattern final : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.xori to SPIR-V operations if the type of source is not i1 or
+/// a vector of i1.
+struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
+  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
+/// vector of i1.
+struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
+  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
+/// i1.
+struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
+  using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.extui to spv.Select if the type of source is i1 or vector of
+/// i1.
+struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
+  using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
+/// i1.
+struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
+  using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts type-casting arithmetic operations to SPIR-V operations.
+template <typename Op, typename SPIRVOp>
+struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts integer compare operation on i1 type operands to SPIR-V ops.
+class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
+public:
+  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts integer compare operation to SPIR-V ops.
+class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
+public:
+  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts floating-point comparison operations to SPIR-V ops.
+class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts floating point NaN check to SPIR-V ops. This pattern requires
+/// Kernel capability.
+class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts floating point NaN check to SPIR-V ops. This pattern does not
+/// require additional capability.
+class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Conversion Helpers
+//===----------------------------------------------------------------------===//
+
+/// Converts the given `srcAttr` into a boolean attribute if it holds an
+/// integral value. Returns null attribute if conversion fails.
+static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
+  if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
+    return boolAttr;
+  if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
+    return builder.getBoolAttr(intAttr.getValue().getBoolValue());
+  return BoolAttr();
+}
+
+/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
+/// Returns null attribute if conversion fails.
+static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
+                                      Builder builder) {
+  // If the source number uses less active bits than the target bitwidth, then
+  // it should be safe to convert.
+  if (srcAttr.getValue().isIntN(dstType.getWidth()))
+    return builder.getIntegerAttr(dstType, srcAttr.getInt());
+
+  // XXX: Try again by interpreting the source number as a signed value.
+  // Although integers in the arith dialect are signless, they can represent
+  // a signed number. It's the operation that decides how to interpret. This is
+  // dangerous, but it seems there is no good way of handling this if we still
+  // want to change the bitwidth. Emit a message at least.
+  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
+    auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
+    LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
+                            << dstAttr << "' for type '" << dstType << "'\n");
+    return dstAttr;
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
+                          << "' illegal: cannot fit into target type '"
+                          << dstType << "'\n");
+  return IntegerAttr();
+}
+
+/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
+/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
+static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
+                                  Builder builder) {
+  // Only support converting to float for now.
+  if (!dstType.isF32())
+    return FloatAttr();
+
+  // Try to convert the source floating-point number to single precision.
+  APFloat dstVal = srcAttr.getValue();
+  bool losesInfo = false;
+  APFloat::opStatus status =
+      dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
+  if (status != APFloat::opOK || losesInfo) {
+    LLVM_DEBUG(llvm::dbgs()
+               << srcAttr << " illegal: cannot fit into converted type '"
+               << dstType << "'\n");
+    return FloatAttr();
+  }
+
+  return builder.getF32FloatAttr(dstVal.convertToFloat());
+}
+
+/// Returns true if the given `type` is a boolean scalar or vector type.
+static bool isBoolScalarOrVector(Type type) {
+  if (type.isInteger(1))
+    return true;
+  if (auto vecType = type.dyn_cast<VectorType>())
+    return vecType.getElementType().isInteger(1);
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantOp with composite type
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
+    arith::ConstantOp constOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto srcType = constOp.getType().dyn_cast<ShapedType>();
+  if (!srcType)
+    return failure();
+
+  // arith.constant should only have vector or tensor types.
+  assert((srcType.isa<VectorType, RankedTensorType>()));
+
+  auto dstType = getTypeConverter()->convertType(srcType);
+  if (!dstType)
+    return failure();
+
+  auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
+  // Check for a null attribute *before* querying its type: calling getType()
+  // on a null DenseElementsAttr dereferences a null storage pointer.
+  if (!dstElementsAttr)
+    return failure();
+  ShapedType dstAttrType = dstElementsAttr.getType();
+
+  // If the composite type has more than one dimensions, perform linearization.
+  if (srcType.getRank() > 1) {
+    if (srcType.isa<RankedTensorType>()) {
+      dstAttrType = RankedTensorType::get(srcType.getNumElements(),
+                                          srcType.getElementType());
+      dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
+    } else {
+      // TODO: add support for large vectors.
+      return failure();
+    }
+  }
+
+  Type srcElemType = srcType.getElementType();
+  Type dstElemType;
+  // Tensor types are converted to SPIR-V array types; vector types are
+  // converted to SPIR-V vector/array types.
+  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
+    dstElemType = arrayType.getElementType();
+  else
+    dstElemType = dstType.cast<VectorType>().getElementType();
+
+  // If the source and destination element types are different, perform
+  // attribute conversion.
+  if (srcElemType != dstElemType) {
+    SmallVector<Attribute, 8> elements;
+    if (srcElemType.isa<FloatType>()) {
+      for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
+        FloatAttr dstAttr =
+            convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
+        if (!dstAttr)
+          return failure();
+        elements.push_back(dstAttr);
+      }
+    } else if (srcElemType.isInteger(1)) {
+      return failure();
+    } else {
+      for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
+        IntegerAttr dstAttr = convertIntegerAttr(
+            srcAttr, dstElemType.cast<IntegerType>(), rewriter);
+        if (!dstAttr)
+          return failure();
+        elements.push_back(dstAttr);
+      }
+    }
+
+    // Unfortunately, we cannot use dialect-specific types for element
+    // attributes; element attributes only works with builtin types. So we need
+    // to prepare another converted builtin types for the destination elements
+    // attribute.
+    if (dstAttrType.isa<RankedTensorType>())
+      dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
+    else
+      dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+
+    dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
+  }
+
+  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
+                                                 dstElementsAttr);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantOp with scalar type
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConstantScalarOpPattern::matchAndRewrite(
+    arith::ConstantOp constOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Type srcType = constOp.getType();
+  if (!srcType.isIntOrIndexOrFloat())
+    return failure();
+
+  Type dstType = getTypeConverter()->convertType(srcType);
+  if (!dstType)
+    return failure();
+
+  // Floating-point types.
+  if (srcType.isa<FloatType>()) {
+    auto srcAttr = constOp.value().cast<FloatAttr>();
+    auto dstAttr = srcAttr;
+
+    // Floating-point types not supported in the target environment are all
+    // converted to float type.
+    if (srcType != dstType) {
+      dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
+      if (!dstAttr)
+        return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+    return success();
+  }
+
+  // Bool type.
+  if (srcType.isInteger(1)) {
+    // arith.constant can use 0/1 instead of true/false for i1 values. We need
+    // to handle that here.
+    auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
+    if (!dstAttr)
+      return failure();
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+    return success();
+  }
+
+  // IndexType or IntegerType. Index values are converted to 32-bit integer
+  // values when converting to SPIR-V.
+  auto srcAttr = constOp.value().cast<IntegerAttr>();
+  auto dstAttr =
+      convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
+  if (!dstAttr)
+    return failure();
+  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RemSIOpPattern
+//===----------------------------------------------------------------------===//
+
+/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
+/// the sign of `signOperand`.
+///
+/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
+/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
+/// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
+/// if either operand can be negative. Emulate it via spv.UMod.
+static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
+                                    Value signOperand, OpBuilder &builder) {
+  assert(lhs.getType() == rhs.getType());
+  assert(lhs == signOperand || rhs == signOperand);
+
+  Type type = lhs.getType();
+
+  // Calculate the remainder with spv.UMod.
+  Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
+  Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
+  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
+
+  // Fix the sign.
+  Value isPositive;
+  if (lhs == signOperand)
+    isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
+  else
+    isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
+  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
+  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
+}
+
+LogicalResult
+RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  // arith.remsi takes the sign of the dividend, i.e. the first operand.
+  Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0],
+                                        adaptor.getOperands()[1],
+                                        adaptor.getOperands()[0], rewriter);
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// BitwiseOpPattern
+//===----------------------------------------------------------------------===//
+
+template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
+LogicalResult
+BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
+    Op op, typename Op::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  assert(adaptor.getOperands().size() == 2);
+  auto dstType =
+      this->getTypeConverter()->convertType(op.getResult().getType());
+  if (!dstType)
+    return failure();
+  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
+    rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
+                                                         adaptor.getOperands());
+  } else {
+    rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
+                                                         adaptor.getOperands());
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XOrIOpLogicalPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
+    arith::XOrIOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  assert(adaptor.getOperands().size() == 2);
+
+  // Boolean operands are handled by XOrIOpBooleanPattern.
+  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+    return failure();
+
+  auto dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return failure();
+  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
+                                                   adaptor.getOperands());
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XOrIOpBooleanPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
+    arith::XOrIOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  assert(adaptor.getOperands().size() == 2);
+
+  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+    return failure();
+
+  auto dstType = getTypeConverter()->convertType(op.getType());
+  if (!dstType)
+    return failure();
+  // For booleans, a ^ b is equivalent to a != b.
+  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
+                                                        adaptor.getOperands());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// UIToFPI1Pattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  auto srcType = adaptor.getOperands().front().getType();
+  if (!isBoolScalarOrVector(srcType))
+    return failure();
+
+  auto dstType =
+      this->getTypeConverter()->convertType(op.getResult().getType());
+  // Bail out instead of building constants of a null type.
+  if (!dstType)
+    return failure();
+  Location loc = op.getLoc();
+  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+  rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+      op, dstType, adaptor.getOperands().front(), one, zero);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ExtUII1Pattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  auto srcType = adaptor.getOperands().front().getType();
+  if (!isBoolScalarOrVector(srcType))
+    return failure();
+
+  auto dstType =
+      this->getTypeConverter()->convertType(op.getResult().getType());
+  // Bail out instead of building constants of a null type.
+  if (!dstType)
+    return failure();
+  Location loc = op.getLoc();
+  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+  rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+      op, dstType, adaptor.getOperands().front(), one, zero);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TruncII1Pattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  auto dstType =
+      this->getTypeConverter()->convertType(op.getResult().getType());
+  // A null converted type cannot be queried; treat it as a non-match.
+  if (!dstType || !isBoolScalarOrVector(dstType))
+    return failure();
+
+  Location loc = op.getLoc();
+  auto srcType = adaptor.getOperands().front().getType();
+  // Check if (x & 1) == 1.
+  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
+  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
+      loc, srcType, adaptor.getOperands()[0], mask);
+  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+
+  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TypeCastingOpPattern
+//===----------------------------------------------------------------------===//
+
+template <typename Op, typename SPIRVOp>
+LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
+    Op op, typename Op::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  assert(adaptor.getOperands().size() == 1);
+  auto srcType = adaptor.getOperands().front().getType();
+  auto dstType =
+      this->getTypeConverter()->convertType(op.getResult().getType());
+  // A null converted type cannot be queried; treat it as a non-match.
+  if (!dstType)
+    return failure();
+  // i1 sources/results are handled by the dedicated *I1Pattern lowerings.
+  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+    return failure();
+  if (dstType == srcType) {
+    // Due to type conversion, we are seeing the same source and target type.
+    // Then we can just erase this operation by forwarding its operand.
+    rewriter.replaceOp(op, adaptor.getOperands().front());
+  } else {
+    rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
+                                                  adaptor.getOperands());
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOpBooleanPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
+    arith::CmpIOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Type operandType = op.lhs().getType();
+  if (!isBoolScalarOrVector(operandType))
+    return failure();
+
+  switch (op.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp)                                        \
+  case cmpPredicate:                                                           \
+    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
+                                         adaptor.lhs(), adaptor.rhs());        \
+    return success();
+
+    DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
+    DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
+
+#undef DISPATCH
+  default:;
+  }
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+  Type operandType = op.lhs().getType();
+  if (isBoolScalarOrVector(operandType))
+    return failure();
+
+  switch (op.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp)                                        \
+  case cmpPredicate:                                                           \
+    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
+        operandType != this->getTypeConverter()->convertType(operandType)) {   \
+      return op.emitError(                                                     \
+          "bitwidth emulation is not implemented yet on unsigned op");         \
+    }                                                                          \
+    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
+                                         adaptor.lhs(), adaptor.rhs());        \
+    return success();
+
+    DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
+    DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
+    DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
+    DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
+    DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
+    DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
+    DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
+    DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
+    DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
+    DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
+
+#undef DISPATCH
+  }
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+  switch (op.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp)                                        \
+  case cmpPredicate:                                                           \
+    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
+                                         adaptor.lhs(), adaptor.rhs());        \
+    return success();
+
+    // Ordered.
+    DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
+    DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
+    DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
+    DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
+    DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
+    DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
+    // Unordered.
+    DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
+    DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
+    DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
+    DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
+    DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
+    DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
+
+#undef DISPATCH
+
+  default:
+    break;
+  }
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOpNanKernelPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
+    arith::CmpFOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
+    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.lhs(),
+                                                  adaptor.rhs());
+    return success();
+  }
+
+  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
+    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.lhs(),
+                                                    adaptor.rhs());
+    return success();
+  }
+
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpFOpNanNonePattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
+    arith::CmpFOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
+      op.getPredicate() != arith::CmpFPredicate::UNO)
+    return failure();
+
+  Location loc = op.getLoc();
+
+  // UNO(a, b) == isnan(a) || isnan(b); ORD is its negation.
+  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.lhs());
+  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.rhs());
+
+  Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+  if (op.getPredicate() == arith::CmpFPredicate::ORD)
+    replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+
+  rewriter.replaceOp(op, replace);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mlir::arith::populateArithmeticToSPIRVPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<
+    ConstantCompositeOpPattern,
+    ConstantScalarOpPattern,
+    spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
+    RemSIOpPattern,
+    BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
+    BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
+    XOrIOpLogicalPattern, XOrIOpBooleanPattern,
+    spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>,
+    spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>,
+    TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
+    TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
+    TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
+    TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
+    TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
+    TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
+    TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
+    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    CmpIOpBooleanPattern, CmpIOpPattern,
+    CmpFOpNanNonePattern, CmpFOpPattern
+  >(typeConverter, patterns.getContext());
+  // clang-format on
+
+  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
+  // capability is available.
+  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
+                                       /*benefit=*/2);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ConvertArithmeticToSPIRVPass
+    : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
+  void runOnFunction() override {
+    auto module = getOperation()->getParentOfType<ModuleOp>();
+    auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+    auto target = SPIRVConversionTarget::get(targetAttr);
+
+    SPIRVTypeConverter::Options options;
+    options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+    SPIRVTypeConverter typeConverter(targetAttr, options);
+
+    RewritePatternSet patterns(&getContext());
+    mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(getOperation(), *target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
+  return std::make_unique<ConvertArithmeticToSPIRVPass>();
+}
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArithmeticToSPIRV
+  ArithmeticToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithmeticToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRSPIRVConversion
+  MLIRSPIRV
+  )
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include
"mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -903,9 +904,9 @@ LogicalResult matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto count = - rewriter.create(op->getLoc(), rewriter.getI64Type(), - rewriter.getI64IntegerAttr(op.count())); + auto count = rewriter.create( + op->getLoc(), rewriter.getI64Type(), + rewriter.getI64IntegerAttr(op.count())); auto operand = adaptor.operand(); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, @@ -1008,7 +1009,8 @@ converter, ctx); ConversionTarget target(*ctx); - target.addLegalOp(); + target + .addLegalOp(); target.addLegalDialect(); // All operations from Async dialect must be lowered to the runtime API and diff --git a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRAsync MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,4 +1,6 @@ add_subdirectory(AffineToStandard) +add_subdirectory(ArithmeticToLLVM) +add_subdirectory(ArithmeticToSPIRV) add_subdirectory(ArmNeon2dToIntr) add_subdirectory(AsyncToLLVM) add_subdirectory(ComplexToLLVM) diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -11,8 +11,10 @@ #include "../PassDetail.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include 
"mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" using namespace mlir; using namespace mlir::LLVM; diff --git a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt @@ -8,6 +8,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRComplex MLIRIR MLIRMath diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -12,6 +12,7 @@ #include #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -33,21 +34,21 @@ Value real = rewriter.create(loc, type, adaptor.complex()); Value imag = rewriter.create(loc, type, adaptor.complex()); - Value realSqr = rewriter.create(loc, real, real); - Value imagSqr = rewriter.create(loc, imag, imag); - Value sqNorm = rewriter.create(loc, realSqr, imagSqr); + Value realSqr = rewriter.create(loc, real, real); + Value imagSqr = rewriter.create(loc, imag, imag); + Value sqNorm = rewriter.create(loc, realSqr, imagSqr); rewriter.replaceOpWithNewOp(op, sqNorm); return success(); } }; -template +template struct ComparisonOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using ResultCombiner = std::conditional_t::value, - AndOp, OrOp>; + arith::AndIOp, arith::OrIOp>; LogicalResult matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, @@ -60,8 +61,10 @@ Value imagLhs = 
rewriter.create(loc, type, adaptor.lhs()); Value realRhs = rewriter.create(loc, type, adaptor.rhs()); Value imagRhs = rewriter.create(loc, type, adaptor.rhs()); - Value realComparison = rewriter.create(loc, p, realLhs, realRhs); - Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); + Value realComparison = + rewriter.create(loc, p, realLhs, realRhs); + Value imagComparison = + rewriter.create(loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, realComparison, imagComparison); @@ -138,139 +141,150 @@ // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom // // See https://dl.acm.org/citation.cfm?id=368661 for more details. - Value rhsRealImagRatio = rewriter.create(loc, rhsReal, rhsImag); - Value rhsRealImagDenom = rewriter.create( - loc, rhsImag, rewriter.create(loc, rhsRealImagRatio, rhsReal)); - Value realNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), lhsImag); + Value rhsRealImagRatio = + rewriter.create(loc, rhsReal, rhsImag); + Value rhsRealImagDenom = rewriter.create( + loc, rhsImag, + rewriter.create(loc, rhsRealImagRatio, rhsReal)); + Value realNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), + lhsImag); Value resultReal1 = - rewriter.create(loc, realNumerator1, rhsRealImagDenom); - Value imagNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), lhsReal); + rewriter.create(loc, realNumerator1, rhsRealImagDenom); + Value imagNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), + lhsReal); Value resultImag1 = - rewriter.create(loc, imagNumerator1, rhsRealImagDenom); - - Value rhsImagRealRatio = rewriter.create(loc, rhsImag, rhsReal); - Value rhsImagRealDenom = rewriter.create( - loc, rhsReal, rewriter.create(loc, rhsImagRealRatio, rhsImag)); - Value realNumerator2 = rewriter.create( - loc, lhsReal, rewriter.create(loc, lhsImag, rhsImagRealRatio)); + rewriter.create(loc, 
imagNumerator1, rhsRealImagDenom); + + Value rhsImagRealRatio = + rewriter.create(loc, rhsImag, rhsReal); + Value rhsImagRealDenom = rewriter.create( + loc, rhsReal, + rewriter.create(loc, rhsImagRealRatio, rhsImag)); + Value realNumerator2 = rewriter.create( + loc, lhsReal, + rewriter.create(loc, lhsImag, rhsImagRealRatio)); Value resultReal2 = - rewriter.create(loc, realNumerator2, rhsImagRealDenom); - Value imagNumerator2 = rewriter.create( - loc, lhsImag, rewriter.create(loc, lhsReal, rhsImagRealRatio)); + rewriter.create(loc, realNumerator2, rhsImagRealDenom); + Value imagNumerator2 = rewriter.create( + loc, lhsImag, + rewriter.create(loc, lhsReal, rhsImagRealRatio)); Value resultImag2 = - rewriter.create(loc, imagNumerator2, rhsImagRealDenom); + rewriter.create(loc, imagNumerator2, rhsImagRealDenom); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. - Value zero = rewriter.create(loc, elementType, - rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create(loc, rhsReal); - Value rhsRealIsZero = - rewriter.create(loc, CmpFPredicate::OEQ, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create(loc, rhsImag); - Value rhsImagIsZero = - rewriter.create(loc, CmpFPredicate::OEQ, rhsImagAbs, zero); - Value lhsRealIsNotNaN = - rewriter.create(loc, CmpFPredicate::ORD, lhsReal, zero); - Value lhsImagIsNotNaN = - rewriter.create(loc, CmpFPredicate::ORD, lhsImag, zero); + Value zero = rewriter.create( + loc, elementType, rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = rewriter.create(loc, rhsReal); + Value rhsRealIsZero = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); + Value rhsImagAbs = rewriter.create(loc, rhsImag); + Value rhsImagIsZero = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); + Value lhsRealIsNotNaN = rewriter.create( + loc, arith::CmpFPredicate::ORD, lhsReal, zero); + Value lhsImagIsNotNaN = rewriter.create( + loc, arith::CmpFPredicate::ORD, 
lhsImag, zero); Value lhsContainsNotNaNValue = - rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create( + rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = rewriter.create( loc, lhsContainsNotNaNValue, - rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create( + rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = rewriter.create( loc, elementType, rewriter.getFloatAttr( elementType, APFloat::getInf(elementType.getFloatSemantics()))); - Value infWithSignOfRhsReal = rewriter.create(loc, inf, rhsReal); + Value infWithSignOfRhsReal = + rewriter.create(loc, inf, rhsReal); Value infinityResultReal = - rewriter.create(loc, infWithSignOfRhsReal, lhsReal); + rewriter.create(loc, infWithSignOfRhsReal, lhsReal); Value infinityResultImag = - rewriter.create(loc, infWithSignOfRhsReal, lhsImag); + rewriter.create(loc, infWithSignOfRhsReal, lhsImag); // Case 2. Infinite numerator, finite denominator. 
- Value rhsRealFinite = - rewriter.create(loc, CmpFPredicate::ONE, rhsRealAbs, inf); - Value rhsImagFinite = - rewriter.create(loc, CmpFPredicate::ONE, rhsImagAbs, inf); - Value rhsFinite = rewriter.create(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create(loc, lhsReal); - Value lhsRealInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create(loc, lhsImag); - Value lhsImagInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, lhsImagAbs, inf); + Value rhsRealFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); + Value rhsImagFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); + Value rhsFinite = + rewriter.create(loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = rewriter.create(loc, lhsReal); + Value lhsRealInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagAbs = rewriter.create(loc, lhsImag); + Value lhsImagInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create( + rewriter.create(loc, lhsInfinite, rhsFinite); + Value one = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create( + Value lhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsRealInfinite, one, zero), lhsReal); - Value lhsImagIsInfWithSign = rewriter.create( + Value lhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsImagInfinite, one, zero), lhsImag); Value lhsRealIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); + rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); Value lhsImagIsInfWithSignTimesRhsImag = - rewriter.create(loc, 
lhsImagIsInfWithSign, rhsImag); - Value resultReal3 = rewriter.create( + rewriter.create(loc, lhsImagIsInfWithSign, rhsImag); + Value resultReal3 = rewriter.create( loc, inf, - rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, - lhsImagIsInfWithSignTimesRhsImag)); + rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, + lhsImagIsInfWithSignTimesRhsImag)); Value lhsRealIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsRealIsInfWithSign, rhsImag); + rewriter.create(loc, lhsRealIsInfWithSign, rhsImag); Value lhsImagIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsImagIsInfWithSign, rhsReal); - Value resultImag3 = rewriter.create( + rewriter.create(loc, lhsImagIsInfWithSign, rhsReal); + Value resultImag3 = rewriter.create( loc, inf, - rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, - lhsRealIsInfWithSignTimesRhsImag)); + rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, + lhsRealIsInfWithSignTimesRhsImag)); // Case 3: Finite numerator, infinite denominator. - Value lhsRealFinite = - rewriter.create(loc, CmpFPredicate::ONE, lhsRealAbs, inf); - Value lhsImagFinite = - rewriter.create(loc, CmpFPredicate::ONE, lhsImagAbs, inf); - Value lhsFinite = rewriter.create(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, rhsImagAbs, inf); + Value lhsRealFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); + Value lhsImagFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); + Value lhsFinite = + rewriter.create(loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); 
Value finiteNumInfiniteDenom = - rewriter.create(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create( + rewriter.create(loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsRealInfinite, one, zero), rhsReal); - Value rhsImagIsInfWithSign = rewriter.create( + Value rhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsImagInfinite, one, zero), rhsImag); Value rhsRealIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); + rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsImag, rhsImagIsInfWithSign); - Value resultReal4 = rewriter.create( + rewriter.create(loc, lhsImag, rhsImagIsInfWithSign); + Value resultReal4 = rewriter.create( loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, - rhsImagIsInfWithSignTimesLhsImag)); + rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, + rhsImagIsInfWithSignTimesLhsImag)); Value rhsRealIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsImag, rhsRealIsInfWithSign); + rewriter.create(loc, lhsImag, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsReal, rhsImagIsInfWithSign); - Value resultImag4 = rewriter.create( + rewriter.create(loc, lhsReal, rhsImagIsInfWithSign); + Value resultImag4 = rewriter.create( loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, - rhsImagIsInfWithSignTimesLhsReal)); + rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, + rhsImagIsInfWithSignTimesLhsReal)); - Value realAbsSmallerThanImagAbs = rewriter.create( - loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); + Value realAbsSmallerThanImagAbs = rewriter.create( + loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); Value resultReal = rewriter.create(loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); Value resultImag = rewriter.create(loc, realAbsSmallerThanImagAbs, @@ -288,12 
+302,12 @@ Value resultImagSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = - rewriter.create(loc, CmpFPredicate::UNO, resultReal, zero); - Value resultImagIsNaN = - rewriter.create(loc, CmpFPredicate::UNO, resultImag, zero); + Value resultRealIsNaN = rewriter.create( + loc, arith::CmpFPredicate::UNO, resultReal, zero); + Value resultImagIsNaN = rewriter.create( + loc, arith::CmpFPredicate::UNO, resultImag, zero); Value resultIsNaN = - rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); + rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); Value resultRealWithSpecialCases = rewriter.create( loc, resultIsNaN, resultRealSpecialCase1, resultReal); Value resultImagWithSpecialCases = rewriter.create( @@ -321,9 +335,9 @@ rewriter.create(loc, elementType, adaptor.complex()); Value expReal = rewriter.create(loc, real); Value cosImag = rewriter.create(loc, imag); - Value resultReal = rewriter.create(loc, expReal, cosImag); + Value resultReal = rewriter.create(loc, expReal, cosImag); Value sinImag = rewriter.create(loc, imag); - Value resultImag = rewriter.create(loc, expReal, sinImag); + Value resultImag = rewriter.create(loc, expReal, sinImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -364,9 +378,9 @@ Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); - Value one = - b.create(elementType, b.getFloatAttr(elementType, 1)); - Value realPlusOne = b.create(real, one); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1)); + Value realPlusOne = b.create(real, one); Value newComplex = b.create(type, realPlusOne, imag); rewriter.replaceOpWithNewOp(op, type, newComplex); return success(); @@ -384,126 +398,162 @@ auto elementType = type.getElementType().cast(); Value lhsReal = b.create(elementType, adaptor.lhs()); - Value lhsRealAbs = b.create(lhsReal); + Value lhsRealAbs = b.create(lhsReal); Value 
lhsImag = b.create(elementType, adaptor.lhs()); - Value lhsImagAbs = b.create(lhsImag); + Value lhsImagAbs = b.create(lhsImag); Value rhsReal = b.create(elementType, adaptor.rhs()); - Value rhsRealAbs = b.create(rhsReal); + Value rhsRealAbs = b.create(rhsReal); Value rhsImag = b.create(elementType, adaptor.rhs()); - Value rhsImagAbs = b.create(rhsImag); + Value rhsImagAbs = b.create(rhsImag); - Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); - Value lhsRealTimesRhsRealAbs = b.create(lhsRealTimesRhsReal); - Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); - Value lhsImagTimesRhsImagAbs = b.create(lhsImagTimesRhsImag); - Value real = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); + Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); + Value lhsRealTimesRhsRealAbs = b.create(lhsRealTimesRhsReal); + Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); + Value lhsImagTimesRhsImagAbs = b.create(lhsImagTimesRhsImag); + Value real = + b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); - Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); - Value lhsImagTimesRhsRealAbs = b.create(lhsImagTimesRhsReal); - Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); - Value lhsRealTimesRhsImagAbs = b.create(lhsRealTimesRhsImag); - Value imag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); + Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); + Value lhsImagTimesRhsRealAbs = b.create(lhsImagTimesRhsReal); + Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); + Value lhsRealTimesRhsImagAbs = b.create(lhsRealTimesRhsImag); + Value imag = + b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); // Handle cases where the "naive" calculation results in NaN values. 
- Value realIsNan = b.create(CmpFPredicate::UNO, real, real); - Value imagIsNan = b.create(CmpFPredicate::UNO, imag, imag); - Value isNan = b.create(realIsNan, imagIsNan); + Value realIsNan = + b.create(arith::CmpFPredicate::UNO, real, real); + Value imagIsNan = + b.create(arith::CmpFPredicate::UNO, imag, imag); + Value isNan = b.create(realIsNan, imagIsNan); - Value inf = b.create( + Value inf = b.create( elementType, b.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); // Case 1. `lhsReal` or `lhsImag` are infinite. - Value lhsRealIsInf = b.create(CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagIsInf = b.create(CmpFPredicate::OEQ, lhsImagAbs, inf); - Value lhsIsInf = b.create(lhsRealIsInf, lhsImagIsInf); - Value rhsRealIsNan = b.create(CmpFPredicate::UNO, rhsReal, rhsReal); - Value rhsImagIsNan = b.create(CmpFPredicate::UNO, rhsImag, rhsImag); - Value zero = b.create(elementType, b.getZeroAttr(elementType)); - Value one = - b.create(elementType, b.getFloatAttr(elementType, 1)); + Value lhsRealIsInf = + b.create(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagIsInf = + b.create(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); + Value lhsIsInf = b.create(lhsRealIsInf, lhsImagIsInf); + Value rhsRealIsNan = + b.create(arith::CmpFPredicate::UNO, rhsReal, rhsReal); + Value rhsImagIsNan = + b.create(arith::CmpFPredicate::UNO, rhsImag, rhsImag); + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1)); Value lhsRealIsInfFloat = b.create(lhsRealIsInf, one, zero); lhsReal = b.create( - lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), lhsReal); + lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), + lhsReal); Value lhsImagIsInfFloat = b.create(lhsImagIsInf, one, zero); lhsImag = b.create( - lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), lhsImag); - Value lhsIsInfAndRhsRealIsNan = b.create(lhsIsInf, rhsRealIsNan); - rhsReal = b.create(lhsIsInfAndRhsRealIsNan, 
- b.create(zero, rhsReal), rhsReal); - Value lhsIsInfAndRhsImagIsNan = b.create(lhsIsInf, rhsImagIsNan); - rhsImag = b.create(lhsIsInfAndRhsImagIsNan, - b.create(zero, rhsImag), rhsImag); + lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), + lhsImag); + Value lhsIsInfAndRhsRealIsNan = + b.create(lhsIsInf, rhsRealIsNan); + rhsReal = + b.create(lhsIsInfAndRhsRealIsNan, + b.create(zero, rhsReal), rhsReal); + Value lhsIsInfAndRhsImagIsNan = + b.create(lhsIsInf, rhsImagIsNan); + rhsImag = + b.create(lhsIsInfAndRhsImagIsNan, + b.create(zero, rhsImag), rhsImag); // Case 2. `rhsReal` or `rhsImag` are infinite. - Value rhsRealIsInf = b.create(CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagIsInf = b.create(CmpFPredicate::OEQ, rhsImagAbs, inf); - Value rhsIsInf = b.create(rhsRealIsInf, rhsImagIsInf); - Value lhsRealIsNan = b.create(CmpFPredicate::UNO, lhsReal, lhsReal); - Value lhsImagIsNan = b.create(CmpFPredicate::UNO, lhsImag, lhsImag); + Value rhsRealIsInf = + b.create(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagIsInf = + b.create(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); + Value rhsIsInf = b.create(rhsRealIsInf, rhsImagIsInf); + Value lhsRealIsNan = + b.create(arith::CmpFPredicate::UNO, lhsReal, lhsReal); + Value lhsImagIsNan = + b.create(arith::CmpFPredicate::UNO, lhsImag, lhsImag); Value rhsRealIsInfFloat = b.create(rhsRealIsInf, one, zero); rhsReal = b.create( - rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), rhsReal); + rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), + rhsReal); Value rhsImagIsInfFloat = b.create(rhsImagIsInf, one, zero); rhsImag = b.create( - rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), rhsImag); - Value rhsIsInfAndLhsRealIsNan = b.create(rhsIsInf, lhsRealIsNan); - lhsReal = b.create(rhsIsInfAndLhsRealIsNan, - b.create(zero, lhsReal), lhsReal); - Value rhsIsInfAndLhsImagIsNan = b.create(rhsIsInf, lhsImagIsNan); - lhsImag = b.create(rhsIsInfAndLhsImagIsNan, - b.create(zero, lhsImag), lhsImag); - Value recalc = 
b.create(lhsIsInf, rhsIsInf); + rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), + rhsImag); + Value rhsIsInfAndLhsRealIsNan = + b.create(rhsIsInf, lhsRealIsNan); + lhsReal = + b.create(rhsIsInfAndLhsRealIsNan, + b.create(zero, lhsReal), lhsReal); + Value rhsIsInfAndLhsImagIsNan = + b.create(rhsIsInf, lhsImagIsNan); + lhsImag = + b.create(rhsIsInfAndLhsImagIsNan, + b.create(zero, lhsImag), lhsImag); + Value recalc = b.create(lhsIsInf, rhsIsInf); // Case 3. One of the pairwise products of left hand side with right hand // side is infinite. - Value lhsRealTimesRhsRealIsInf = - b.create(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); - Value lhsImagTimesRhsImagIsInf = - b.create(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); - Value isSpecialCase = - b.create(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf); - Value lhsRealTimesRhsImagIsInf = - b.create(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); - isSpecialCase = b.create(isSpecialCase, lhsRealTimesRhsImagIsInf); - Value lhsImagTimesRhsRealIsInf = - b.create(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); - isSpecialCase = b.create(isSpecialCase, lhsImagTimesRhsRealIsInf); + Value lhsRealTimesRhsRealIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); + Value lhsImagTimesRhsImagIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); + Value isSpecialCase = b.create(lhsRealTimesRhsRealIsInf, + lhsImagTimesRhsImagIsInf); + Value lhsRealTimesRhsImagIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); + isSpecialCase = + b.create(isSpecialCase, lhsRealTimesRhsImagIsInf); + Value lhsImagTimesRhsRealIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); + isSpecialCase = + b.create(isSpecialCase, lhsImagTimesRhsRealIsInf); Type i1Type = b.getI1Type(); - Value notRecalc = b.create( - recalc, b.create(i1Type, b.getIntegerAttr(i1Type, 1))); - isSpecialCase = b.create(isSpecialCase, notRecalc); + Value notRecalc = 
b.create( + recalc, + b.create(i1Type, b.getIntegerAttr(i1Type, 1))); + isSpecialCase = b.create(isSpecialCase, notRecalc); Value isSpecialCaseAndLhsRealIsNan = - b.create(isSpecialCase, lhsRealIsNan); - lhsReal = b.create(isSpecialCaseAndLhsRealIsNan, - b.create(zero, lhsReal), lhsReal); + b.create(isSpecialCase, lhsRealIsNan); + lhsReal = + b.create(isSpecialCaseAndLhsRealIsNan, + b.create(zero, lhsReal), lhsReal); Value isSpecialCaseAndLhsImagIsNan = - b.create(isSpecialCase, lhsImagIsNan); - lhsImag = b.create(isSpecialCaseAndLhsImagIsNan, - b.create(zero, lhsImag), lhsImag); + b.create(isSpecialCase, lhsImagIsNan); + lhsImag = + b.create(isSpecialCaseAndLhsImagIsNan, + b.create(zero, lhsImag), lhsImag); Value isSpecialCaseAndRhsRealIsNan = - b.create(isSpecialCase, rhsRealIsNan); - rhsReal = b.create(isSpecialCaseAndRhsRealIsNan, - b.create(zero, rhsReal), rhsReal); + b.create(isSpecialCase, rhsRealIsNan); + rhsReal = + b.create(isSpecialCaseAndRhsRealIsNan, + b.create(zero, rhsReal), rhsReal); Value isSpecialCaseAndRhsImagIsNan = - b.create(isSpecialCase, rhsImagIsNan); - rhsImag = b.create(isSpecialCaseAndRhsImagIsNan, - b.create(zero, rhsImag), rhsImag); - recalc = b.create(recalc, isSpecialCase); - recalc = b.create(isNan, recalc); + b.create(isSpecialCase, rhsImagIsNan); + rhsImag = + b.create(isSpecialCaseAndRhsImagIsNan, + b.create(zero, rhsImag), rhsImag); + recalc = b.create(recalc, isSpecialCase); + recalc = b.create(isNan, recalc); // Recalculate real part. - lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); - lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); - Value newReal = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); - real = b.create(recalc, b.create(inf, newReal), real); + lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); + lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); + Value newReal = + b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); + real = + b.create(recalc, b.create(inf, newReal), real); // Recalculate imag part. 
- lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); - lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); - Value newImag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); - imag = b.create(recalc, b.create(inf, newImag), imag); + lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); + lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); + Value newImag = + b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); + imag = + b.create(recalc, b.create(inf, newImag), imag); rewriter.replaceOpWithNewOp(op, type, real, imag); return success(); @@ -524,8 +574,8 @@ rewriter.create(loc, elementType, adaptor.complex()); Value imag = rewriter.create(loc, elementType, adaptor.complex()); - Value negReal = rewriter.create(loc, real); - Value negImag = rewriter.create(loc, imag); + Value negReal = rewriter.create(loc, real); + Value negImag = rewriter.create(loc, imag); rewriter.replaceOpWithNewOp(op, type, negReal, negImag); return success(); } @@ -543,13 +593,16 @@ Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); - Value zero = b.create(elementType, b.getZeroAttr(elementType)); - Value realIsZero = b.create(CmpFPredicate::OEQ, real, zero); - Value imagIsZero = b.create(CmpFPredicate::OEQ, imag, zero); - Value isZero = b.create(realIsZero, imagIsZero); + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + Value realIsZero = + b.create(arith::CmpFPredicate::OEQ, real, zero); + Value imagIsZero = + b.create(arith::CmpFPredicate::OEQ, imag, zero); + Value isZero = b.create(realIsZero, imagIsZero); auto abs = b.create(elementType, adaptor.complex()); - Value realSign = b.create(real, abs); - Value imagSign = b.create(imag, abs); + Value realSign = b.create(real, abs); + Value imagSign = b.create(imag, abs); Value sign = b.create(type, realSign, imagSign); rewriter.replaceOpWithNewOp(op, isZero, adaptor.complex(), sign); return success(); @@ -562,10 +615,10 @@ // clang-format off patterns.add< AbsOpConversion, - 
ComparisonOpConversion, - ComparisonOpConversion, - BinaryComplexOpConversion, - BinaryComplexOpConversion, + ComparisonOpConversion, + ComparisonOpConversion, + BinaryComplexOpConversion, + BinaryComplexOpConversion, DivOpConversion, ExpOpConversion, LogOpConversion, @@ -590,7 +643,8 @@ populateComplexToStandardConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt --- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt @@ -29,6 +29,7 @@ ${NVPTX_LIBS} LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRAsyncToLLVM MLIRGPUTransforms MLIRIR diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -349,6 +350,7 @@ target.addIllegalDialect(); + mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt @@ -11,11 +11,11 @@ MLIRGPUToNVVMIncGen 
LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRGPUOps MLIRGPUToGPURuntimeTransforms MLIRLLVMCommonConversion MLIRLLVMIR - MLIRMemRef MLIRMemRefToLLVM MLIRNVVMIR MLIRPass diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -13,11 +13,13 @@ #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" @@ -169,6 +171,8 @@ populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); + mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, + llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns); populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns); @@ -217,14 +221,14 @@ Identifier::get(NVVM::NVVMDialect::getKernelFuncAttrName(), &converter.getContext())); - patterns.add>(converter, "__nv_fabsf", - "__nv_fabs"); + patterns.add>(converter, "__nv_fabsf", + "__nv_fabs"); patterns.add>(converter, "__nv_atanf", "__nv_atan"); patterns.add>(converter, "__nv_atan2f", "__nv_atan2"); - patterns.add>(converter, "__nv_ceilf", - "__nv_ceil"); + patterns.add>(converter, "__nv_ceilf", + "__nv_ceil"); patterns.add>(converter, "__nv_cosf", "__nv_cos"); patterns.add>(converter, "__nv_expf", @@ -233,8 +237,8 @@ "__nv_exp2"); patterns.add>(converter, 
"__nv_expm1f", "__nv_expm1"); - patterns.add>(converter, "__nv_floorf", - "__nv_floor"); + patterns.add>(converter, "__nv_floorf", + "__nv_floor"); patterns.add>(converter, "__nv_logf", "__nv_log"); patterns.add>(converter, "__nv_log1pf", diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRGPUToROCDLIncGen LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRGPUOps MLIRGPUToGPURuntimeTransforms MLIRLLVMCommonConversion diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -72,6 +73,8 @@ populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); + mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, + llvmPatterns); populateVectorToLLVMConversionPatterns(converter, llvmPatterns); populateVectorToROCDLConversionPatterns(converter, llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns); @@ -116,14 +119,14 @@ converter, /*allocaAddrSpace=*/5, Identifier::get(ROCDL::ROCDLDialect::getKernelFuncAttrName(), &converter.getContext())); - patterns.add>(converter, "__ocml_fabs_f32", - "__ocml_fabs_f64"); + patterns.add>(converter, "__ocml_fabs_f32", + "__ocml_fabs_f64"); patterns.add>(converter, "__ocml_atan_f32", "__ocml_atan_f64"); patterns.add>( converter, "__ocml_atan2_f32", "__ocml_atan2_f64"); - patterns.add>(converter, 
"__ocml_ceil_f32", - "__ocml_ceil_f64"); + patterns.add>(converter, "__ocml_ceil_f32", + "__ocml_ceil_f64"); patterns.add>(converter, "__ocml_cos_f32", "__ocml_cos_f64"); patterns.add>(converter, "__ocml_exp_f32", @@ -132,8 +135,8 @@ "__ocml_exp2_f64"); patterns.add>( converter, "__ocml_expm1_f32", "__ocml_expm1_f64"); - patterns.add>(converter, "__ocml_floor_f32", - "__ocml_floor_f64"); + patterns.add>( + converter, "__ocml_floor_f32", "__ocml_floor_f64"); patterns.add>(converter, "__ocml_log_f32", "__ocml_log_f64"); patterns.add>( diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt @@ -6,13 +6,13 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRGPUOps MLIRIR MLIRPass MLIRSCFToSPIRV MLIRSPIRV MLIRSPIRVConversion - MLIRStandard MLIRStandardToSPIRV MLIRSupport MLIRTransforms diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" @@ -63,6 +64,7 @@ // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. 
+ mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -184,7 +184,8 @@ void ConvertLinalgToStandardPass::runOnOperation() { auto module = getOperation(); ConversionTarget target(getContext()); - target.addLegalDialect(); target.addLegalOp(); RewritePatternSet patterns(&getContext()); diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -18,9 +18,16 @@ using namespace mlir; namespace { +using AbsOpLowering = VectorConvertToLLVMPattern; +using CeilOpLowering = VectorConvertToLLVMPattern; +using CopySignOpLowering = + VectorConvertToLLVMPattern; using CosOpLowering = VectorConvertToLLVMPattern; using ExpOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = VectorConvertToLLVMPattern; +using FloorOpLowering = + VectorConvertToLLVMPattern; +using FmaOpLowering = VectorConvertToLLVMPattern; using Log10OpLowering = VectorConvertToLLVMPattern; using Log2OpLowering = VectorConvertToLLVMPattern; @@ -209,10 +216,15 @@ RewritePatternSet &patterns) { // clang-format off patterns.add< + AbsOpLowering, + CeilOpLowering, + CopySignOpLowering, CosOpLowering, ExpOpLowering, Exp2OpLowering, ExpM1OpLowering, + FloorOpLowering, + FmaOpLowering, Log10OpLowering, Log1pOpLowering, Log2OpLowering, diff --git a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt --- a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt +++ b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt @@ -11,6 +11,7 @@ Core 
LINK_LIBS PUBLIC + MLIRArithmetic MLIRMath MLIRStandardOpsTransforms ) diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -61,7 +62,7 @@ if (shape.size() != 1) return failure(); - Value result = rewriter.create( + Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, FloatAttr::get(vecType.getElementType(), 0.0))); for (auto i = 0; i < shape.front(); ++i) { @@ -135,8 +136,8 @@ populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalDialect(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -29,31 +30,6 @@ // normal RewritePattern. namespace { - -/// Converts unary and binary standard operations to SPIR-V operations. 
-template -class UnaryAndBinaryOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() <= 2); - auto dstType = this->getTypeConverter()->convertType(operation.getType()); - if (!dstType) - return failure(); - if (SPIRVOp::template hasTrait() && - dstType != operation.getType()) { - return operation.emitError( - "bitwidth emulation is not implemented yet on unsigned op"); - } - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - return success(); - } -}; - /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to @@ -76,7 +52,6 @@ return success(); } }; - } // namespace //===----------------------------------------------------------------------===// @@ -86,15 +61,19 @@ namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern>( + patterns.add< + Log1pOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern>( typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt --- a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt +++ b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt @@ 
-8,8 +8,9 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIROpenACC - MLIRTransforms MLIRSCF + MLIRTransforms ) diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp --- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp +++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp @@ -8,6 +8,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -33,7 +34,7 @@ return success(); // Condition is not a constant. - if (!op.ifCond().template getDefiningOp()) { + if (!op.ifCond().template getDefiningOp()) { auto ifOp = rewriter.create(op.getLoc(), TypeRange(), op.ifCond(), false); rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt b/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt @@ -12,6 +12,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRIR MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -65,6 +66,7 @@ // Convert to OpenMP operations with LLVM IR dialect RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); + 
mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -23,6 +23,10 @@ class OpenACCDialect; } // end namespace acc +namespace arith { +class ArithmeticDialect; +} // end namespace arith + namespace complex { class ComplexDialect; } // end namespace complex diff --git a/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt b/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt @@ -11,6 +11,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRAffineToStandard + MLIRArithmetic MLIRComplex MLIRGPUTransforms MLIRIR diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/ParallelLoopMapper.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -83,7 +84,8 @@ // Get a Value that corresponds to the loop step. If the step is an attribute, // materialize a corresponding constant using builder. static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { - return builder.create(forOp.getLoc(), forOp.getStep()); + return builder.create(forOp.getLoc(), + forOp.getStep()); } // Get a Value for the loop lower bound. If the value requires computation, @@ -169,8 +171,8 @@ // Return true if the value is obviously a constant "one". 
static bool isConstantOne(Value value) { - if (auto def = value.getDefiningOp()) - return def.getValue() == 1; + if (auto def = value.getDefiningOp()) + return def.value() == 1; return false; } @@ -194,11 +196,11 @@ return llvm::None; } - Value range = - builder.create(currentLoop.getLoc(), upperBound, lowerBound); + Value range = builder.create(currentLoop.getLoc(), + upperBound, lowerBound); Value step = getOrCreateStep(currentLoop, builder); if (!isConstantOne(step)) - range = builder.create(currentLoop.getLoc(), range, step); + range = builder.create(currentLoop.getLoc(), range, step); dims.push_back(range); lbs.push_back(lowerBound); @@ -222,9 +224,10 @@ OpBuilder builder(rootForOp.getOperation()); // Prepare the grid and block sizes for the launch operation. If there is // no loop mapped to a specific dimension, use constant "1" as its size. - Value constOne = (numBlockDims < 3 || numThreadDims < 3) - ? builder.create(rootForOp.getLoc(), 1) - : nullptr; + Value constOne = + (numBlockDims < 3 || numThreadDims < 3) + ? builder.create(rootForOp.getLoc(), 1) + : nullptr; Value gridSizeX = numBlockDims > 0 ? dims[0] : constOne; Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne; Value gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; @@ -265,10 +268,10 @@ : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); Value step = steps[en.index()]; if (!isConstantOne(step)) - id = builder.create(rootForOp.getLoc(), step, id); + id = builder.create(rootForOp.getLoc(), step, id); Value ivReplacement = - builder.create(rootForOp.getLoc(), *lbArgumentIt, id); + builder.create(rootForOp.getLoc(), *lbArgumentIt, id); en.value().replaceAllUsesWith(ivReplacement); std::advance(lbArgumentIt, 1); std::advance(stepArgumentIt, 1); @@ -314,33 +317,33 @@ /// `upperBound`. 
static Value deriveStaticUpperBound(Value upperBound, PatternRewriter &rewriter) { - if (auto op = upperBound.getDefiningOp()) { + if (auto op = upperBound.getDefiningOp()) { return op; } if (auto minOp = upperBound.getDefiningOp()) { for (const AffineExpr &result : minOp.map().getResults()) { if (auto constExpr = result.dyn_cast()) { - return rewriter.create(minOp.getLoc(), - constExpr.getValue()); + return rewriter.create(minOp.getLoc(), + constExpr.getValue()); } } } - if (auto multiplyOp = upperBound.getDefiningOp()) { - if (auto lhs = dyn_cast_or_null( + if (auto multiplyOp = upperBound.getDefiningOp()) { + if (auto lhs = dyn_cast_or_null( deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter) .getDefiningOp())) - if (auto rhs = dyn_cast_or_null( + if (auto rhs = dyn_cast_or_null( deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter) .getDefiningOp())) { // Assumptions about the upper bound of minimum computations no longer // work if multiplied by a negative value, so abort in this case. - if (lhs.getValue() < 0 || rhs.getValue() < 0) + if (lhs.value() < 0 || rhs.value() < 0) return {}; - return rewriter.create( - multiplyOp.getLoc(), lhs.getValue() * rhs.getValue()); + return rewriter.create( + multiplyOp.getLoc(), lhs.value() * rhs.value()); } } @@ -416,8 +419,9 @@ launchIndependent](Value val) -> Value { if (launchIndependent(val)) return val; - if (ConstantOp constOp = val.getDefiningOp()) - return rewriter.create(constOp.getLoc(), constOp.getValue()); + if (auto constOp = val.getDefiningOp()) + return rewriter.create(constOp.getLoc(), + constOp.value()); return {}; }; @@ -460,17 +464,17 @@ // conditional. If the lower-bound is constant or defined before the // launch, we can use it in the launch bounds. Otherwise fail. if (!launchIndependent(lowerBound) && - !isa_and_nonnull(lowerBound.getDefiningOp())) + !isa_and_nonnull(lowerBound.getDefiningOp())) return failure(); // The step must also be constant or defined outside of the loop nest. 
if (!launchIndependent(step) && - !isa_and_nonnull(step.getDefiningOp())) + !isa_and_nonnull(step.getDefiningOp())) return failure(); // If the upper-bound is constant or defined before the launch, we can // use it in the launch bounds directly. Otherwise try derive a bound. bool boundIsPrecise = launchIndependent(upperBound) || - isa_and_nonnull(upperBound.getDefiningOp()); + isa_and_nonnull(upperBound.getDefiningOp()); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(launchOp); @@ -510,8 +514,8 @@ if (!boundIsPrecise) { // We are using an approximation, create a surrounding conditional. Value originalBound = std::get<3>(config); - CmpIOp pred = rewriter.create( - loc, CmpIPredicate::slt, newIndex, + arith::CmpIOp pred = rewriter.create( + loc, arith::CmpIPredicate::slt, newIndex, cloningMap.lookupOrDefault(originalBound)); scf::IfOp ifOp = rewriter.create(loc, pred, false); rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); @@ -595,7 +599,8 @@ // Create a launch operation. We start with bound one for all grid/block // sizes. Those will be refined later as we discover them from mappings. 
Location loc = parallelOp.getLoc(); - Value constantOne = rewriter.create(parallelOp.getLoc(), 1); + Value constantOne = + rewriter.create(parallelOp.getLoc(), 1); gpu::LaunchOp launchOp = rewriter.create( parallelOp.getLoc(), constantOne, constantOne, constantOne, constantOne, constantOne, constantOne); diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -10,6 +10,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/SCFToGPU/SCFToGPU.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/SCF/SCF.h" diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRAnalysis + MLIRArithmetic MLIRLLVMIR MLIROpenMP MLIRSCF diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "../PassDetail.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" @@ -248,27 +249,27 @@ // Match simple binary reductions that can be expressed with atomicrmw. 
Type type = reduce.operand().getType(); Block &reduction = reduce.getRegion().front(); - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getFloatAttr(type, 0.0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl( builder, symbolTable, reduce, builder.getIntegerAttr( @@ -279,25 +280,25 @@ // Match simple binary reductions that cannot be expressed with atomicrmw. // TODO: add atomic region using cmpxchg (which needs atomic load to be // available as an op). - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { return createDecl(builder, symbolTable, reduce, builder.getFloatAttr(type, 1.0)); } // Match select-based min/max reductions. 
bool isMin; - if (matchSelectReduction( - reduction, {CmpFPredicate::OLT, CmpFPredicate::OLE}, - {CmpFPredicate::OGT, CmpFPredicate::OGE}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, + {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || matchSelectReduction( reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { return createDecl(builder, symbolTable, reduce, minMaxValueForFloat(type, !isMin)); } - if (matchSelectReduction( - reduction, {CmpIPredicate::slt, CmpIPredicate::sle}, - {CmpIPredicate::sgt, CmpIPredicate::sge}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, + {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || matchSelectReduction( reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { @@ -307,9 +308,9 @@ isMin ? 
LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, decl, reduce); } - if (matchSelectReduction( - reduction, {CmpIPredicate::ult, CmpIPredicate::ule}, - {CmpIPredicate::ugt, CmpIPredicate::uge}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, + {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || matchSelectReduction( reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt @@ -9,6 +9,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRMemRefToSPIRV MLIRSPIRV MLIRSPIRVConversion diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" @@ -43,6 +44,7 @@ // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. 
+ mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt b/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRSCF MLIRTransforms ) diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -314,7 +315,7 @@ Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.step(); - auto stepped = rewriter.create(loc, iv, step).getResult(); + auto stepped = rewriter.create(loc, iv, step).getResult(); if (!stepped) return failure(); @@ -341,8 +342,8 @@ // With the body block done, we can fill in the condition block. 
rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = - rewriter.create(loc, CmpIPredicate::slt, iv, upperBound); + auto comparison = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, upperBound); rewriter.create(loc, comparison, firstBodyBlock, ArrayRef(), endBlock, ArrayRef()); diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -0,0 +1,45 @@ +//===- Pattern.h - SPIRV Common Conversion Patterns -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H +#define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H + +#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace spirv { + +/// Converts unary and binary standard operations to SPIR-V operations. 
+template +class UnaryAndBinaryOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getOperands().size() <= 2); + auto dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return failure(); + if (SPIRVOp::template hasTrait() && + dstType != op.getType()) { + return op.emitError( + "bitwidth emulation is not implemented yet on unsigned op"); + } + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + return success(); + } +}; + +} // end namespace spirv +} // end namespace mlir + +#endif // MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H diff --git a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt @@ -11,6 +11,7 @@ intrinsics_gen LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRGPUOps MLIRSPIRV MLIRSPIRVUtils diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -287,6 +288,8 @@ auto *context = module.getContext(); RewritePatternSet patterns(context); LLVMTypeConverter typeConverter(context, options); + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); 
populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); patterns.add(typeConverter); diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt @@ -17,6 +17,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRShape MLIRTensor diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -75,13 +76,13 @@ // number of extent tensors and shifted offsets into them. Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, ValueRange rankDiffs, Value outputDimension) { - Value one = lb.create(1); + Value one = lb.create(1); Value broadcastedDim = one; for (auto tup : llvm::zip(extentTensors, rankDiffs)) { Value shape = std::get<0>(tup); Value rankDiff = std::get<1>(tup); - Value outOfBounds = - lb.create(CmpIPredicate::ult, outputDimension, rankDiff); + Value outOfBounds = lb.create(arith::CmpIPredicate::ult, + outputDimension, rankDiff); Type indexTy = lb.getIndexType(); broadcastedDim = lb.create( @@ -97,13 +98,14 @@ // - otherwise, take the extent as-is. // Note that this logic remains correct in the presence // of dimensions of zero extent. 
- Value lesserRankOperandDimension = - b.create(loc, indexTy, outputDimension, rankDiff); + Value lesserRankOperandDimension = b.create( + loc, indexTy, outputDimension, rankDiff); Value lesserRankOperandExtent = b.create( loc, shape, ValueRange{lesserRankOperandDimension}); - Value dimIsOne = b.create(loc, CmpIPredicate::eq, - lesserRankOperandExtent, one); + Value dimIsOne = + b.create(loc, arith::CmpIPredicate::eq, + lesserRankOperandExtent, one); Value dim = b.create(loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); b.create(loc, dim); @@ -125,7 +127,7 @@ auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); + Value zero = lb.create(0); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -139,13 +141,14 @@ // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - Value rankIsGreater = lb.create(CmpIPredicate::ugt, v, maxRank); + Value rankIsGreater = + lb.create(arith::CmpIPredicate::ugt, v, maxRank); maxRank = lb.create(rankIsGreater, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. 
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return lb.create(indexTy, maxRank, v); })); Value replacement = lb.create( @@ -186,7 +189,7 @@ SmallVector extentOperands; for (auto extent : op.shape()) { extentOperands.push_back( - rewriter.create(loc, extent.getLimitedValue())); + rewriter.create(loc, extent.getLimitedValue())); } Type indexTy = rewriter.getIndexType(); Value tensor = @@ -210,7 +213,8 @@ LogicalResult ConstSizeOpConversion::matchAndRewrite( ConstSizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(op, op.value().getSExtValue()); + rewriter.replaceOpWithNewOp( + op, op.value().getSExtValue()); return success(); } @@ -236,8 +240,8 @@ auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); - Value one = lb.create(1); + Value zero = lb.create(0); + Value one = lb.create(1); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -251,18 +255,19 @@ // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - Value rankIsGreater = lb.create(CmpIPredicate::ugt, v, maxRank); + Value rankIsGreater = + lb.create(arith::CmpIPredicate::ugt, v, maxRank); maxRank = lb.create(rankIsGreater, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. 
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return lb.create(indexTy, maxRank, v); })); Type i1Ty = rewriter.getI1Type(); Value trueVal = - rewriter.create(loc, i1Ty, rewriter.getBoolAttr(true)); + rewriter.create(loc, i1Ty, rewriter.getBoolAttr(true)); auto reduceResult = lb.create( loc, zero, maxRank, one, ValueRange{trueVal}, @@ -277,8 +282,8 @@ for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; - Value outOfBounds = - b.create(loc, CmpIPredicate::ult, iv, rankDiff); + Value outOfBounds = b.create( + loc, arith::CmpIPredicate::ult, iv, rankDiff); broadcastable = b.create( loc, TypeRange{i1Ty}, outOfBounds, @@ -290,18 +295,19 @@ // Every value needs to be either 1, or the same non-1 // value to be broadcastable in this dim. Value operandDimension = - b.create(loc, indexTy, iv, rankDiff); + b.create(loc, indexTy, iv, rankDiff); Value dimensionExtent = b.create( loc, shape, ValueRange{operandDimension}); - Value equalOne = b.create(loc, CmpIPredicate::eq, - dimensionExtent, one); - Value equalBroadcasted = - b.create(loc, CmpIPredicate::eq, - dimensionExtent, broadcastedDim); - Value result = b.create( + Value equalOne = b.create( + loc, arith::CmpIPredicate::eq, dimensionExtent, one); + Value equalBroadcasted = b.create( + loc, arith::CmpIPredicate::eq, dimensionExtent, + broadcastedDim); + Value result = b.create( loc, broadcastable, - b.create(loc, equalOne, equalBroadcasted)); + b.create(loc, equalOne, + equalBroadcasted)); b.create(loc, result); }) .getResult(0); @@ -389,8 +395,8 @@ auto loc = op.getLoc(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = rewriter.create(loc, indexTy, adaptor.shape(), zero); @@ -433,20 +439,20 @@ /// %c0 = constant 0 : index /// %0 = dim 
%arg0, %c0 : tensor /// %1 = dim %arg1, %c0 : tensor -/// %2 = cmpi "eq", %0, %1 : index +/// %2 = arith.cmpi "eq", %0, %1 : index /// %result = scf.if %2 -> (i1) { -/// %c1 = constant 1 : index -/// %true = constant true +/// %c1 = arith.constant 1 : index +/// %true = arith.constant true /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { /// %5 = tensor.extract %arg0[%arg2] : tensor /// %6 = tensor.extract %arg1[%arg2] : tensor -/// %7 = cmpi "eq", %5, %6 : index -/// %8 = and %arg3, %7 : i1 +/// %7 = arith.cmpi "eq", %5, %6 : index +/// %8 = arith.andi %arg3, %7 : i1 /// scf.yield %8 : i1 /// } /// scf.yield %4 : i1 /// } else { -/// %false = constant false +/// %false = arith.constant false /// scf.yield %false : i1 /// } /// @@ -468,14 +474,14 @@ Type i1Ty = rewriter.getI1Type(); if (op.shapes().size() <= 1) { - rewriter.replaceOpWithNewOp(op, i1Ty, - rewriter.getBoolAttr(true)); + rewriter.replaceOpWithNewOp(op, i1Ty, + rewriter.getBoolAttr(true)); return success(); } auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); - Value zero = rewriter.create(loc, 0); + Value zero = rewriter.create(loc, 0); Value firstShape = adaptor.shapes().front(); Value firstRank = rewriter.create(loc, indexTy, firstShape, zero); @@ -483,13 +489,14 @@ // Generate a linear sequence of compares, all with firstShape as lhs. 
for (Value shape : adaptor.shapes().drop_front(1)) { Value rank = rewriter.create(loc, indexTy, shape, zero); - Value eqRank = - rewriter.create(loc, CmpIPredicate::eq, firstRank, rank); + Value eqRank = rewriter.create(loc, arith::CmpIPredicate::eq, + firstRank, rank); auto same = rewriter.create( loc, i1Ty, eqRank, [&](OpBuilder &b, Location loc) { - Value one = b.create(loc, 1); - Value init = b.create(loc, i1Ty, b.getBoolAttr(true)); + Value one = b.create(loc, 1); + Value init = + b.create(loc, i1Ty, b.getBoolAttr(true)); auto loop = b.create( loc, zero, firstRank, one, ValueRange{init}, [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { @@ -497,19 +504,21 @@ Value lhsExtent = b.create(loc, firstShape, iv); Value rhsExtent = b.create(loc, shape, iv); - Value eqExtent = b.create(loc, CmpIPredicate::eq, - lhsExtent, rhsExtent); - Value conjNext = b.create(loc, conj, eqExtent); + Value eqExtent = b.create( + loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); + Value conjNext = b.create(loc, conj, eqExtent); b.create(loc, ValueRange({conjNext})); }); b.create(loc, loop.getResults()); }, [&](OpBuilder &b, Location loc) { - Value result = b.create(loc, i1Ty, b.getBoolAttr(false)); + Value result = + b.create(loc, i1Ty, b.getBoolAttr(false)); b.create(loc, result); }); result = !result ? same.getResult(0) - : rewriter.create(loc, result, same.getResult(0)); + : rewriter.create(loc, result, + same.getResult(0)); } rewriter.replaceOp(op, result); return success(); @@ -549,8 +558,8 @@ Value extent = rewriter.create(loc, tensor, i); extentValues.push_back(extent); } else { - Value extent = - rewriter.create(loc, rankedTensorTy.getDimSize(i)); + Value extent = rewriter.create( + loc, rankedTensorTy.getDimSize(i)); extentValues.push_back(extent); } } @@ -598,20 +607,20 @@ return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value zero = b.create(0); + Value zero = b.create(0); Value rank = b.create(adaptor.operand(), zero); // index < 0 ? 
index + rank : index Value originalIndex = adaptor.index(); - Value add = b.create(originalIndex, rank); + Value add = b.create(originalIndex, rank); Value indexIsNegative = - b.create(CmpIPredicate::slt, originalIndex, zero); + b.create(arith::CmpIPredicate::slt, originalIndex, zero); Value index = b.create(indexIsNegative, add, originalIndex); - Value one = b.create(1); + Value one = b.create(1); Value head = b.create(adaptor.operand(), zero, index, one); - Value tailSize = b.create(rank, index); + Value tailSize = b.create(rank, index); Value tail = b.create(adaptor.operand(), index, tailSize, one); rewriter.replaceOp(op, {head, tail}); @@ -655,8 +664,8 @@ // Setup target legality. MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target - .addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); // Setup conversion patterns. @@ -675,8 +684,8 @@ populateWithGenerated(patterns); patterns.add< AnyOpConversion, - BinaryOpConversion, - BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, BroadcastOpConverter, ConstShapeOpConverter, ConstSizeOpConversion, diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRAnalysis + MLIRArithmeticToLLVM MLIRDataLayoutInterfaces MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -13,6 +13,7 @@ #include "../PassDetail.h" #include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include 
"mlir/Conversion/LLVMCommon/VectorPattern.h" @@ -20,14 +21,12 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" @@ -390,54 +389,7 @@ }; // Straightforward lowerings. -using AbsFOpLowering = VectorConvertToLLVMPattern; -using AddFOpLowering = VectorConvertToLLVMPattern; -using AddIOpLowering = VectorConvertToLLVMPattern; -using AndOpLowering = VectorConvertToLLVMPattern; -using BitcastOpLowering = - VectorConvertToLLVMPattern; -using CeilFOpLowering = VectorConvertToLLVMPattern; -using CopySignOpLowering = - VectorConvertToLLVMPattern; -using DivFOpLowering = VectorConvertToLLVMPattern; -using FPExtOpLowering = VectorConvertToLLVMPattern; -using FPToSIOpLowering = VectorConvertToLLVMPattern; -using FPToUIOpLowering = VectorConvertToLLVMPattern; -using FPTruncOpLowering = - VectorConvertToLLVMPattern; -using FloorFOpLowering = VectorConvertToLLVMPattern; -using FmaFOpLowering = VectorConvertToLLVMPattern; -using MulFOpLowering = VectorConvertToLLVMPattern; -using MulIOpLowering = VectorConvertToLLVMPattern; -using NegFOpLowering = VectorConvertToLLVMPattern; -using OrOpLowering = VectorConvertToLLVMPattern; -using RemFOpLowering = VectorConvertToLLVMPattern; -using SIToFPOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = VectorConvertToLLVMPattern; -using SignExtendIOpLowering = - VectorConvertToLLVMPattern; -using ShiftLeftOpLowering = - VectorConvertToLLVMPattern; -using SignedDivIOpLowering = - VectorConvertToLLVMPattern; 
-using SignedRemIOpLowering = - VectorConvertToLLVMPattern; -using SignedShiftRightOpLowering = - VectorConvertToLLVMPattern; -using SubFOpLowering = VectorConvertToLLVMPattern; -using SubIOpLowering = VectorConvertToLLVMPattern; -using TruncateIOpLowering = - VectorConvertToLLVMPattern; -using UIToFPOpLowering = VectorConvertToLLVMPattern; -using UnsignedDivIOpLowering = - VectorConvertToLLVMPattern; -using UnsignedRemIOpLowering = - VectorConvertToLLVMPattern; -using UnsignedShiftRightOpLowering = - VectorConvertToLLVMPattern; -using XOrOpLowering = VectorConvertToLLVMPattern; -using ZeroExtendIOpLowering = - VectorConvertToLLVMPattern; /// Lower `std.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is @@ -651,118 +603,6 @@ } }; -// The lowering of index_cast becomes an integer conversion since index becomes -// an integer. If the bit width of the source and target integer types is the -// same, just erase the cast. If the target type is wider, sign-extend the -// value, otherwise truncate it. 
-struct IndexCastOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto targetType = - typeConverter->convertType(indexCastOp.getResult().getType()); - auto targetElementType = - typeConverter - ->convertType(getElementTypeOrSelf(indexCastOp.getResult())) - .cast(); - auto sourceElementType = - getElementTypeOrSelf(adaptor.in()).cast(); - unsigned targetBits = targetElementType.getWidth(); - unsigned sourceBits = sourceElementType.getWidth(); - - if (targetBits == sourceBits) - rewriter.replaceOp(indexCastOp, adaptor.in()); - else if (targetBits < sourceBits) - rewriter.replaceOpWithNewOp(indexCastOp, targetType, - adaptor.in()); - else - rewriter.replaceOpWithNewOp(indexCastOp, targetType, - adaptor.in()); - return success(); - } -}; - -// Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two -// enums share the numerical values so just cast. -template -static LLVMPredType convertCmpPredicate(StdPredType pred) { - return static_cast(pred); -} - -struct CmpIOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto operandType = adaptor.lhs().getType(); - auto resultType = cmpiOp.getResult().getType(); - - // Handle the scalar and 1D vector cases. 
- if (!operandType.isa()) { - rewriter.replaceOpWithNewOp( - cmpiOp, typeConverter->convertType(resultType), - convertCmpPredicate(cmpiOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type"); - - return LLVM::detail::handleMultidimensionalVectors( - cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - CmpIOpAdaptor adaptor(operands); - return rewriter.create( - cmpiOp.getLoc(), llvm1DVectorTy, - convertCmpPredicate(cmpiOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - }, - rewriter); - - return success(); - } -}; - -struct CmpFOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto operandType = adaptor.lhs().getType(); - auto resultType = cmpfOp.getResult().getType(); - - // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { - rewriter.replaceOpWithNewOp( - cmpfOp, typeConverter->convertType(resultType), - convertCmpPredicate(cmpfOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type"); - - return LLVM::detail::handleMultidimensionalVectors( - cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - CmpFOpAdaptor adaptor(operands); - return rewriter.create( - cmpfOp.getLoc(), llvm1DVectorTy, - convertCmpPredicate(cmpfOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - }, - rewriter); - } -}; - // Base class for LLVM IR lowering terminator operations with successors. 
template struct OneToOneLLVMTerminatorLowering @@ -1131,57 +971,20 @@ populateStdToLLVMFuncOpConversionPattern(converter, patterns); // clang-format off patterns.add< - AbsFOpLowering, - AddFOpLowering, - AddIOpLowering, - AndOpLowering, AssertOpLowering, AtomicRMWOpLowering, - BitcastOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, - CeilFOpLowering, - CmpFOpLowering, - CmpIOpLowering, CondBranchOpLowering, - CopySignOpLowering, ConstantOpLowering, - DivFOpLowering, - FloorFOpLowering, - FmaFOpLowering, GenericAtomicRMWOpLowering, - FPExtOpLowering, - FPToSIOpLowering, - FPToUIOpLowering, - FPTruncOpLowering, - IndexCastOpLowering, - MulFOpLowering, - MulIOpLowering, - NegFOpLowering, - OrOpLowering, - RemFOpLowering, RankOpLowering, ReturnOpLowering, - SIToFPOpLowering, SelectOpLowering, - ShiftLeftOpLowering, - SignExtendIOpLowering, - SignedDivIOpLowering, - SignedRemIOpLowering, - SignedShiftRightOpLowering, SplatOpLowering, SplatNdOpLowering, - SubFOpLowering, - SubIOpLowering, - SwitchOpLowering, - TruncateIOpLowering, - UIToFPOpLowering, - UnsignedDivIOpLowering, - UnsignedRemIOpLowering, - UnsignedShiftRightOpLowering, - XOrOpLowering, - ZeroExtendIOpLowering>(converter); + SwitchOpLowering>(converter); // clang-format on } @@ -1231,6 +1034,7 @@ RewritePatternSet patterns(&getContext()); populateStdToLLVMConversionPatterns(typeConverter, patterns); + arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -10,8 +10,9 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRIR - MLIRMath + MLIRMathToSPIRV MLIRMemRef MLIRPass MLIRSPIRV diff --git 
a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -29,15 +30,6 @@ // Utility functions //===----------------------------------------------------------------------===// -/// Returns true if the given `type` is a boolean scalar or vector type. -static bool isBoolScalarOrVector(Type type) { - if (type.isInteger(1)) - return true; - if (auto vecType = type.dyn_cast()) - return vecType.getElementType().isInteger(1); - return false; -} - /// Converts the given `srcAttr` into a boolean attribute if it holds an /// integral value. Returns null attribute if conversion fails. static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { @@ -98,35 +90,6 @@ return builder.getF32FloatAttr(dstVal.convertToFloat()); } -/// Returns signed remainder for `lhs` and `rhs` and lets the result follow -/// the sign of `signOperand`. -/// -/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment -/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative -/// the result is undefined." So we cannot directly use spv.SRem/spv.SMod -/// if either operand can be negative. Emulate it via spv.UMod. -static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, - Value signOperand, OpBuilder &builder) { - assert(lhs.getType() == rhs.getType()); - assert(lhs == signOperand || rhs == signOperand); - - Type type = lhs.getType(); - - // Calculate the remainder with spv.UMod. 
- Value lhsAbs = builder.create(loc, type, lhs); - Value rhsAbs = builder.create(loc, type, rhs); - Value abs = builder.create(loc, lhsAbs, rhsAbs); - - // Fix the sign. - Value isPositive; - if (lhs == signOperand) - isPositive = builder.create(loc, lhs, lhsAbs); - else - isPositive = builder.create(loc, rhs, rhsAbs); - Value absNegate = builder.create(loc, type, abs); - return builder.create(loc, type, isPositive, abs, absNegate); -} - //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -137,71 +100,6 @@ namespace { -/// Converts unary and binary standard operations to SPIR-V operations. -template -class UnaryAndBinaryOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() <= 2); - auto dstType = this->getTypeConverter()->convertType(operation.getType()); - if (!dstType) - return failure(); - if (SPIRVOp::template hasTrait() && - dstType != operation.getType()) { - return operation.emitError( - "bitwidth emulation is not implemented yet on unsigned op"); - } - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - return success(); - } -}; - -/// Converts std.remi_signed to SPIR-V ops. -/// -/// This cannot be merged into the template unary/binary pattern due to -/// Vulkan restrictions over spv.SRem and spv.SMod. -class SignedRemIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(SignedRemIOp remOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts bitwise standard operations to SPIR-V operations. 
This is a special -/// pattern other than the BinaryOpPatternPattern because if the operands are -/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For -/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. -template -class BitwiseOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() == 2); - auto dstType = - this->getTypeConverter()->convertType(operation.getResult().getType()); - if (!dstType) - return failure(); - if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { - rewriter.template replaceOpWithNewOp( - operation, dstType, adaptor.getOperands()); - } else { - rewriter.template replaceOpWithNewOp( - operation, dstType, adaptor.getOperands()); - } - return success(); - } -}; - /// Converts composite std.constant operation to spv.Constant. class ConstantCompositeOpPattern final : public OpConversionPattern { @@ -223,58 +121,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts floating-point comparison operations to SPIR-V ops. -class CmpFOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating point NaN check to SPIR-V ops. This pattern requires -/// Kernel capability. -class CmpFOpNanKernelPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating point NaN check to SPIR-V ops. This pattern does not -/// require additional capability. 
-class CmpFOpNanNonePattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts integer compare operation on i1 type operands to SPIR-V ops. -class BoolCmpIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts integer compare operation to SPIR-V ops. -class CmpIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - /// Converts std.return to spv.Return. class ReturnOpPattern final : public OpConversionPattern { public: @@ -304,30 +150,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts std.zexti to spv.Select if the type of source is i1 or vector of -/// i1. -class ZeroExtendI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ZeroExtendIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = adaptor.getOperands().front().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); - - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Location loc = op.getLoc(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.template replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); - return success(); - } -}; - /// Converts tensor.extract into loading using access chains from SPIR-V local /// variables. 
class TensorExtractPattern final @@ -389,124 +211,8 @@ int64_t byteCountThreshold; }; -/// Converts std.trunci to spv.Select if the type of result is i1 or vector of -/// i1. -class TruncI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(TruncateIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - if (!isBoolScalarOrVector(dstType)) - return failure(); - - Location loc = op.getLoc(); - auto srcType = adaptor.getOperands().front().getType(); - // Check if (x & 1) == 1. - Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create(loc, maskedSrc, mask); - - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); - return success(); - } -}; - -/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of -/// i1. 
-class UIToFPI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(UIToFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = adaptor.getOperands().front().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); - - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Location loc = op.getLoc(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.template replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); - return success(); - } -}; - -/// Converts type-casting standard operations to SPIR-V operations. -template -class TypeCastingOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() == 1); - auto srcType = adaptor.getOperands().front().getType(); - auto dstType = - this->getTypeConverter()->convertType(operation.getResult().getType()); - if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) - return failure(); - if (dstType == srcType) { - // Due to type conversion, we are seeing the same source and target type. - // Then we can just erase this operation by forwarding its operand. - rewriter.replaceOp(operation, adaptor.getOperands().front()); - } else { - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - } - return success(); - } -}; - -/// Converts std.xor to SPIR-V operations. 
-class XOrOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector -/// of i1. -class BoolXOrOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - } // namespace -//===----------------------------------------------------------------------===// -// SignedRemIOpPattern -//===----------------------------------------------------------------------===// - -LogicalResult SignedRemIOpPattern::matchAndRewrite( - SignedRemIOp remOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value result = emulateSignedRemainder( - remOp.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], - adaptor.getOperands()[0], rewriter); - rewriter.replaceOp(remOp, result); - - return success(); -} - //===----------------------------------------------------------------------===// // ConstantOp with composite type. //===----------------------------------------------------------------------===// @@ -649,143 +355,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -LogicalResult -CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - switch (cmpFOp.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - // Ordered. 
- DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); - DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); - DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); - DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); - DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); - DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); - // Unordered. - DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); - DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); - DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); - DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); - DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); - DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); - -#undef DISPATCH - - default: - break; - } - return failure(); -} - -LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( - CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (cmpFOp.getPredicate() == CmpFPredicate::ORD) { - rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), - adaptor.rhs()); - return success(); - } - - if (cmpFOp.getPredicate() == CmpFPredicate::UNO) { - rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), - adaptor.rhs()); - return success(); - } - - return failure(); -} - -LogicalResult CmpFOpNanNonePattern::matchAndRewrite( - CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (cmpFOp.getPredicate() != CmpFPredicate::ORD && - cmpFOp.getPredicate() != CmpFPredicate::UNO) - return failure(); - - Location loc = cmpFOp.getLoc(); - - Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); - - Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); - if (cmpFOp.getPredicate() == CmpFPredicate::ORD) - replace = rewriter.create(loc, replace); - - rewriter.replaceOp(cmpFOp, replace); - return success(); -} - -//===----------------------------------------------------------------------===// -// CmpIOp 
-//===----------------------------------------------------------------------===// - -LogicalResult -BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type operandType = cmpIOp.lhs().getType(); - if (!isBoolScalarOrVector(operandType)) - return failure(); - - switch (cmpIOp.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); - DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp); - -#undef DISPATCH - default:; - } - return failure(); -} - -LogicalResult -CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type operandType = cmpIOp.lhs().getType(); - if (isBoolScalarOrVector(operandType)) - return failure(); - - switch (cmpIOp.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - if (spirvOp::template hasTrait() && \ - operandType != this->getTypeConverter()->convertType(operandType)) { \ - return cmpIOp.emitError( \ - "bitwidth emulation is not implemented yet on unsigned op"); \ - } \ - rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); - DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); - DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); - DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); - DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); - DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); - DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); - DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); - DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); - DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp); - -#undef DISPATCH - } - return failure(); -} - 
//===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// @@ -833,43 +402,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// XorOp -//===----------------------------------------------------------------------===// - -LogicalResult -XOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); - - if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) - return failure(); - - auto dstType = getTypeConverter()->convertType(xorOp.getType()); - if (!dstType) - return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, - adaptor.getOperands()); - - return success(); -} - -LogicalResult -BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); - - if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) - return failure(); - - auto dstType = getTypeConverter()->convertType(xorOp.getType()); - if (!dstType) - return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, - adaptor.getOperands()); - return success(); -} - //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -880,55 +412,12 @@ MLIRContext *context = patterns.getContext(); patterns.add< - // Unary and binary patterns - BitwiseOpPattern, - BitwiseOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - 
UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, - - // Comparison patterns - BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, - // Constant patterns ConstantCompositeOpPattern, ConstantScalarOpPattern, - ReturnOpPattern, SelectOpPattern, SplatPattern, - - // Type cast patterns - UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern>(typeConverter, - context); - - // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel - // capability is available. - patterns.add(typeConverter, context, - /*benefit=*/2); + ReturnOpPattern, SelectOpPattern, SplatPattern + + >(typeConverter, context); } void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp @@ -12,6 +12,8 @@ #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" +#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -38,10 +40,13 @@ options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; SPIRVTypeConverter typeConverter(targetAttr, options); + // TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV RewritePatternSet 
patterns(context); + arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); + populateMathToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); - populateTensorToSPIRVPatterns(typeConverter, - /*byteCountThreshold=*/64, patterns); + populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, + patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); if (failed(applyPartialConversion(module, *target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt --- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRDialectUtils MLIRIR MLIRLinalg diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" @@ -32,12 +33,12 @@ } template -static mlir::ConstantOp +static arith::ConstantOp createConstFromIntAttribute(Operation *op, std::string attrName, Type requiredAttrType, OpBuilder &rewriter) { auto castedN = static_cast( op->getAttr(attrName).cast().getValue().getSExtValue()); - return rewriter.create( + return rewriter.create( op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); } @@ -50,9 +51,9 @@ } template -static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min, - mlir::ConstantOp max, P pred, - OpBuilder &rewriter) { +static mlir::SelectOp clampHelper(Location 
loc, Value arg, + arith::ConstantOp min, arith::ConstantOp max, + P pred, OpBuilder &rewriter) { auto smallerThanMin = rewriter.create(loc, pred, arg, min); auto minOrArg = rewriter.create(loc, smallerThanMin, min, arg); @@ -83,7 +84,7 @@ highIndices.push_back(rewriter.getIndexAttr(highPad)); } - Value padValue = rewriter.create(loc, padAttr); + Value padValue = rewriter.create(loc, padAttr); return linalg::PadTensorOp::createPadScalarOp( RankedTensorType::get(paddedShape, inputETy), input, padValue, @@ -101,30 +102,30 @@ // tosa::AbsOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa()) { - auto zero = - rewriter.create(loc, rewriter.getZeroAttr(elementTy)); - auto cmp = - rewriter.create(loc, CmpIPredicate::sgt, args[0], zero); - auto neg = rewriter.create(loc, zero, args[0]); + auto zero = rewriter.create( + loc, rewriter.getZeroAttr(elementTy)); + auto cmp = rewriter.create(loc, arith::CmpIPredicate::sgt, + args[0], zero); + auto neg = rewriter.create(loc, zero, args[0]); return rewriter.create(loc, cmp, args[0], neg); } // tosa::AddOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::SubOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::MulOp if (isa(op) && elementTy.isa()) { @@ -133,18 +134,18 @@ "Cannot have shift value for float"); return nullptr; } - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); } // tosa::DivOp if (isa(op) && elementTy.isa()) - return 
rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ReciprocalOp if (isa(op) && elementTy.isa()) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - return rewriter.create(loc, resultTypes, one, args[0]); + rewriter.create(loc, FloatAttr::get(elementTy, 1)); + return rewriter.create(loc, resultTypes, one, args[0]); } if (isa(op) && elementTy.isa()) { @@ -154,12 +155,12 @@ op->getAttr("shift").cast().getValue().getSExtValue(); if (shift > 0) { auto shiftConst = - rewriter.create(loc, shift, /*bitwidth=*/8); + rewriter.create(loc, shift, /*bitwidth=*/8); if (!a.getType().isInteger(32)) - a = rewriter.create(loc, rewriter.getI32Type(), a); + a = rewriter.create(loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) - b = rewriter.create(loc, rewriter.getI32Type(), b); + b = rewriter.create(loc, rewriter.getI32Type(), b); auto result = rewriter.create( loc, rewriter.getI32Type(), a, b, shiftConst, @@ -168,7 +169,7 @@ if (elementTy.isInteger(32)) return result; - return rewriter.create(loc, elementTy, result); + return rewriter.create(loc, elementTy, result); } int aWidth = a.getType().getIntOrFloatBitWidth(); @@ -176,22 +177,22 @@ int cWidth = resultTypes[0].getIntOrFloatBitWidth(); if (aWidth < cWidth) - a = rewriter.create(loc, resultTypes[0], a); + a = rewriter.create(loc, resultTypes[0], a); if (bWidth < cWidth) - b = rewriter.create(loc, resultTypes[0], b); + b = rewriter.create(loc, resultTypes[0], b); - return rewriter.create(loc, resultTypes, a, b); + return rewriter.create(loc, resultTypes, a, b); } // tosa::NegateOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa() && !cast(op).quantization_info()) { auto constant = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); - return rewriter.create(loc, resultTypes, constant, args[0]); + rewriter.create(loc, 
IntegerAttr::get(elementTy, 0)); + return rewriter.create(loc, resultTypes, constant, args[0]); } if (isa(op) && elementTy.isa() && @@ -220,62 +221,59 @@ } Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); - Value zpAddValue = rewriter.create( + Value zpAddValue = rewriter.create( loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue - auto ext = rewriter.create(loc, intermediateType, args[0]); - auto sub = rewriter.create(loc, zpAddValue, ext); + auto ext = rewriter.create(loc, intermediateType, args[0]); + auto sub = rewriter.create(loc, zpAddValue, ext); // Clamp to the negation range. - auto min = rewriter.create( - loc, rewriter.getIntegerAttr( - intermediateType, - APInt::getSignedMinValue(inputBitWidth).getSExtValue())); - auto max = rewriter.create( - loc, rewriter.getIntegerAttr( - intermediateType, - APInt::getSignedMaxValue(inputBitWidth).getSExtValue())); - auto clamp = clampHelper(loc, sub, min, max, - CmpIPredicate::slt, rewriter); + auto min = rewriter.create( + loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(), + intermediateType); + auto max = rewriter.create( + loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(), + intermediateType); + auto clamp = clampHelper( + loc, sub, min, max, arith::CmpIPredicate::slt, rewriter); // Truncate to the final value. 
- return rewriter.create(loc, elementTy, clamp); + return rewriter.create(loc, elementTy, clamp); } // tosa::BitwiseAndOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::BitwiseOrOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::BitwiseNotOp if (isa(op) && elementTy.isa()) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); - auto allOnes = rewriter.create(loc, allOnesAttr); - return rewriter.create(loc, resultTypes, args[0], allOnes); + auto allOnes = rewriter.create(loc, allOnesAttr); + return rewriter.create(loc, resultTypes, args[0], allOnes); } // tosa::BitwiseXOrOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalLeftShiftOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalRightShiftOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ArithmeticRightShiftOp if (isa(op) && elementTy.isa()) { - auto result = - rewriter.create(loc, resultTypes, args); + auto result = rewriter.create(loc, resultTypes, args); auto round = op->getAttr("round").cast().getValue(); if (!round) { return result; @@ -283,40 +281,40 @@ Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); auto one = - rewriter.create(loc, IntegerAttr::get(elementTy, 1)); + rewriter.create(loc, IntegerAttr::get(elementTy, 1)); auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto i1one = - rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); + rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); // 
Checking that input2 != 0 - auto shiftValueGreaterThanZero = - rewriter.create(loc, CmpIPredicate::sgt, args[1], zero); + auto shiftValueGreaterThanZero = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[1], zero); // Checking for the last bit of input1 to be 1 auto subtract = - rewriter.create(loc, resultTypes, args[1], one); - auto shifted = rewriter - .create(loc, resultTypes, - args[0], subtract) - ->getResults(); + rewriter.create(loc, resultTypes, args[1], one); + auto shifted = + rewriter.create(loc, resultTypes, args[0], subtract) + ->getResults(); auto truncated = - rewriter.create(loc, i1Ty, shifted, mlir::None); - auto isInputOdd = rewriter.create(loc, i1Ty, truncated, i1one); + rewriter.create(loc, i1Ty, shifted, mlir::None); + auto isInputOdd = + rewriter.create(loc, i1Ty, truncated, i1one); - auto shouldRound = rewriter.create( + auto shouldRound = rewriter.create( loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); auto extended = - rewriter.create(loc, resultTypes, shouldRound); - return rewriter.create(loc, resultTypes, result, extended); + rewriter.create(loc, resultTypes, shouldRound); + return rewriter.create(loc, resultTypes, result, extended); } // tosa::ClzOp if (isa(op) && elementTy.isa()) { int bitWidth = elementTy.getIntOrFloatBitWidth(); auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); - auto leadingZeros = rewriter.create( + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + auto leadingZeros = rewriter.create( loc, IntegerAttr::get(elementTy, bitWidth)); SmallVector operands = {args[0], leadingZeros, zero}; @@ -332,8 +330,8 @@ Value input = before->getArgument(0); Value zero = before->getArgument(2); - Value inputLargerThanZero = - rewriter.create(loc, CmpIPredicate::ne, input, zero); + Value inputLargerThanZero = rewriter.create( + loc, arith::CmpIPredicate::ne, input, zero); rewriter.create(loc, inputLargerThanZero, before->getArguments()); } @@ -344,12 +342,12 @@ Value input = after->getArgument(0); 
Value leadingZeros = after->getArgument(1); - auto one = rewriter.create( + auto one = rewriter.create( loc, IntegerAttr::get(elementTy, 1)); - auto shifted = rewriter.create( - loc, resultTypes, input, one); + auto shifted = + rewriter.create(loc, resultTypes, input, one); auto leadingZerosMinusOne = - rewriter.create(loc, resultTypes, leadingZeros, one); + rewriter.create(loc, resultTypes, leadingZeros, one); rewriter.create( loc, @@ -362,22 +360,22 @@ // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalNot if (isa(op) && elementTy.isInteger(1)) { - auto one = rewriter.create( + auto one = rewriter.create( loc, rewriter.getIntegerAttr(elementTy, 1)); - return rewriter.create(loc, resultTypes, args[0], one); + return rewriter.create(loc, resultTypes, args[0], one); } // tosa::LogicalOr if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalXor if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::PowOp if (isa(op) && elementTy.isa()) @@ -401,30 +399,30 @@ // tosa::GreaterOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OGT, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OGT, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::sgt, args[0], - args[1]); + return rewriter.create(loc, arith::CmpIPredicate::sgt, + args[0], args[1]); // tosa::GreaterEqualOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OGE, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OGE, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::sge, args[0], - 
args[1]); + return rewriter.create(loc, arith::CmpIPredicate::sge, + args[0], args[1]); // tosa::EqualOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OEQ, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OEQ, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::eq, args[0], - args[1]); + return rewriter.create(loc, arith::CmpIPredicate::eq, + args[0], args[1]); // tosa::SelectOp if (isa(op)) { @@ -435,46 +433,46 @@ // tosa::MaximumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OGT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OGT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { - auto predicate = rewriter.create(loc, CmpIPredicate::sgt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::MinimumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OLT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { - auto predicate = rewriter.create(loc, CmpIPredicate::slt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::CeilOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::FloorOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ClampOp if (isa(op) && 
elementTy.isa()) { - auto min = rewriter.create(loc, elementTy, - op->getAttr("min_fp")); - auto max = rewriter.create(loc, elementTy, - op->getAttr("max_fp")); - return clampHelper(loc, args[0], min, max, CmpFPredicate::OLT, - rewriter); + auto min = rewriter.create(loc, elementTy, + op->getAttr("min_fp")); + auto max = rewriter.create(loc, elementTy, + op->getAttr("max_fp")); + return clampHelper(loc, args[0], min, max, + arith::CmpFPredicate::OLT, rewriter); } if (isa(op) && elementTy.isa()) { @@ -498,41 +496,41 @@ .getSExtValue()); } - auto minVal = - rewriter.create(loc, min, intTy.getIntOrFloatBitWidth()); - auto maxVal = - rewriter.create(loc, max, intTy.getIntOrFloatBitWidth()); - return clampHelper(loc, args[0], minVal, maxVal, - CmpIPredicate::slt, rewriter); + auto minVal = rewriter.create( + loc, min, intTy.getIntOrFloatBitWidth()); + auto maxVal = rewriter.create( + loc, max, intTy.getIntOrFloatBitWidth()); + return clampHelper(loc, args[0], minVal, maxVal, + arith::CmpIPredicate::slt, rewriter); } // tosa::ReluNOp if (isa(op) && elementTy.isa()) { auto zero = - rewriter.create(loc, FloatAttr::get(elementTy, 0)); - auto n = rewriter.create(loc, elementTy, - op->getAttr("max_fp")); - return clampHelper(loc, args[0], zero, n, CmpFPredicate::OLT, - rewriter); + rewriter.create(loc, FloatAttr::get(elementTy, 0)); + auto n = rewriter.create(loc, elementTy, + op->getAttr("max_fp")); + return clampHelper(loc, args[0], zero, n, + arith::CmpFPredicate::OLT, rewriter); } if (isa(op) && elementTy.isa()) { auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto n = createConstFromIntAttribute(op, "max_int", elementTy, rewriter); - return clampHelper(loc, args[0], zero, n, CmpIPredicate::slt, - rewriter); + return clampHelper(loc, args[0], zero, n, + arith::CmpIPredicate::slt, rewriter); } // tosa::SigmoidOp if (isa(op) && elementTy.isa()) { auto one = - rewriter.create(loc, 
FloatAttr::get(elementTy, 1)); - auto negate = rewriter.create(loc, resultTypes, args[0]); + rewriter.create(loc, FloatAttr::get(elementTy, 1)); + auto negate = rewriter.create(loc, resultTypes, args[0]); auto exp = rewriter.create(loc, resultTypes, negate); - auto added = rewriter.create(loc, resultTypes, exp, one); - return rewriter.create(loc, resultTypes, one, added); + auto added = rewriter.create(loc, resultTypes, exp, one); + return rewriter.create(loc, resultTypes, one, added); } // tosa::CastOp @@ -546,92 +544,92 @@ return args.front(); if (srcTy.isa() && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, mlir::None); + return rewriter.create(loc, resultTypes, args, mlir::None); if (srcTy.isa() && dstTy.isa() && !bitExtend) - return rewriter.create(loc, resultTypes, args, + return rewriter.create(loc, resultTypes, args, mlir::None); // 1-bit integers need to be treated as signless. - if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, - mlir::None); + if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); if (srcTy.isInteger(1) && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, - mlir::None); + return rewriter.create(loc, resultTypes, args, + mlir::None); // All other si-to-fp conversions should be handled by SIToFP. - if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, - mlir::None); + if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); // Casting to boolean, floats need to only be checked as not-equal to zero. 
if (srcTy.isa() && dstTy.isInteger(1)) { - Value zero = - rewriter.create(loc, rewriter.getFloatAttr(srcTy, 0.0)); - return rewriter.create(loc, CmpFPredicate::UNE, - args.front(), zero); + Value zero = rewriter.create( + loc, rewriter.getFloatAttr(srcTy, 0.0)); + return rewriter.create(loc, arith::CmpFPredicate::UNE, + args.front(), zero); } - if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto zero = - rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto half = - rewriter.create(loc, rewriter.getF32FloatAttr(0.5f)); + if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { + auto zero = rewriter.create( + loc, rewriter.getF32FloatAttr(0.0f)); + auto half = rewriter.create( + loc, rewriter.getF32FloatAttr(0.5f)); - auto intMin = rewriter.create( + auto intMin = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto intMax = rewriter.create( + auto intMax = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto added = rewriter.create(loc, args[0], half); - auto subbed = rewriter.create(loc, args[0], half); - auto negative = - rewriter.create(loc, CmpFPredicate::OLT, args[0], zero); + auto added = rewriter.create(loc, args[0], half); + auto subbed = rewriter.create(loc, args[0], half); + auto negative = rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], zero); auto rounded = rewriter.create(loc, negative, subbed, added); - auto clamped = clampHelper(loc, rounded, intMin, intMax, - CmpFPredicate::OLT, rewriter); + auto clamped = clampHelper( + loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter); - return rewriter.create(loc, dstTy, clamped); + return rewriter.create(loc, dstTy, clamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. 
if (srcTy.isa() && dstTy.isInteger(1)) { - Value zero = - rewriter.create(loc, 0, srcTy.getIntOrFloatBitWidth()); - return rewriter.create(loc, CmpIPredicate::ne, args.front(), - zero); + Value zero = rewriter.create( + loc, 0, srcTy.getIntOrFloatBitWidth()); + return rewriter.create(loc, arith::CmpIPredicate::ne, + args.front(), zero); } if (srcTy.isa() && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, - mlir::None); + return rewriter.create(loc, resultTypes, args, + mlir::None); if (srcTy.isa() && dstTy.isa() && !bitExtend) { - auto intMin = rewriter.create( + auto intMin = rewriter.create( loc, APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue(), srcTy.getIntOrFloatBitWidth()); - auto intMax = rewriter.create( + auto intMax = rewriter.create( loc, APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue(), srcTy.getIntOrFloatBitWidth()); - auto clamped = clampHelper(loc, args[0], intMin, intMax, - CmpIPredicate::slt, rewriter); - return rewriter.create(loc, dstTy, clamped); + auto clamped = clampHelper( + loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter); + return rewriter.create(loc, dstTy, clamped); } } @@ -814,50 +812,50 @@ PatternRewriter &rewriter) { Location loc = op->getLoc(); if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OLT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) 
&& elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpIPredicate::slt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OGT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OGT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpIPredicate::sgt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return rewriter.create(loc, args); if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return rewriter.create(loc, args); return {}; } @@ -893,7 +891,7 @@ return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); - auto fillValue = rewriter.create(loc, fillValueAttr); + auto fillValue = rewriter.create(loc, fillValueAttr); auto filledTensor = rewriter.create(loc, fillValue, initTensor).result(); @@ -1014,7 +1012,8 @@ weightShape[3], weightShape[0]}; auto weightPermAttr = DenseIntElementsAttr::get( RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm); - Value weightPermValue = rewriter.create(loc, weightPermAttr); + Value weightPermValue = + rewriter.create(loc, weightPermAttr); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); weight = rewriter.create(loc, newWeightTy, weight, @@ -1023,7 +1022,7 @@ Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( loc, resultTy.getShape(), resultETy); - Value zero = rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, 
resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); @@ -1057,8 +1056,8 @@ auto kZp = rewriter.getI32IntegerAttr( quantizationInfo.weight_zp().getValue().getSExtValue()); - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto iZpVal = rewriter.create(loc, iZp); + auto kZpVal = rewriter.create(loc, kZp); Value conv = rewriter .create( @@ -1073,8 +1072,8 @@ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1095,8 +1094,8 @@ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1205,7 +1204,7 @@ Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( loc, linalgConvTy.getShape(), resultETy); - Value zero = rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); @@ -1226,15 +1225,15 @@ getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); rewriter.replaceOp(op, result); } else { - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto iZpVal = rewriter.create(loc, iZp); + auto kZpVal = rewriter.create(loc, kZp); Value conv = rewriter .create( @@ -1250,8 
+1249,8 @@ getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1343,7 +1342,7 @@ auto outputTy = op.getType().cast(); auto outputElementTy = outputTy.getElementType(); auto zeroAttr = rewriter.getZeroAttr(outputElementTy); - Value zero = rewriter.create(loc, zeroAttr); + Value zero = rewriter.create(loc, zeroAttr); auto initTensor = rewriter.create( loc, outputTy.getShape(), outputTy.getElementType()); Value zeroTensor = @@ -1356,10 +1355,10 @@ } auto quantizationInfo = op.quantization_info().getValue(); - auto aZp = rewriter.create( + auto aZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.a_zp().getValue().getSExtValue())); - auto bZp = rewriter.create( + auto bZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.b_zp().getValue().getSExtValue())); rewriter.replaceOpWithNewOp( @@ -1404,14 +1403,15 @@ // When quantized, the input elemeny type is not the same as the output Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy); - Value zero = rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); SmallVector permutation{1, 0}; auto permutationAttr = DenseIntElementsAttr::get( RankedTensorType::get({2}, rewriter.getI64Type()), permutation); - Value permutationValue = rewriter.create(loc, permutationAttr); + Value permutationValue = + rewriter.create(loc, permutationAttr); SmallVector newWeightShape{weightShape[1], weightShape[0]}; Type newWeightTy = @@ -1439,8 +1439,8 @@ indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added 
= nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1449,10 +1449,10 @@ } auto quantizationInfo = op.quantization_info().getValue(); - auto inputZp = rewriter.create( + auto inputZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.input_zp().getValue().getSExtValue())); - auto outputZp = rewriter.create( + auto outputZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.weight_zp().getValue().getSExtValue())); Value matmul = @@ -1469,8 +1469,8 @@ indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1674,7 +1674,7 @@ Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { - multiplierConstant = rewriter.create( + multiplierConstant = rewriter.create( loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector multiplierExprs{ @@ -1682,7 +1682,7 @@ auto multiplierType = RankedTensorType::get({static_cast(multiplierValues.size())}, rewriter.getI32Type()); - genericInputs.push_back(rewriter.create( + genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, @@ -1697,7 +1697,7 @@ Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { - shiftConstant = rewriter.create( + shiftConstant = rewriter.create( loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector shiftExprs = { @@ -1705,7 +1705,7 @@ auto shiftType = RankedTensorType::get({static_cast(shiftValues.size())}, rewriter.getIntegerType(8)); - genericInputs.push_back(rewriter.create( + genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(shiftType, 
shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, @@ -1753,22 +1753,24 @@ valueTy.getIntOrFloatBitWidth()), value) .getResult(0); - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } else { - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } } - value = nestedBuilder.create(nestedLoc, value, inputZp); + value = + nestedBuilder.create(nestedLoc, value, inputZp); value = nestedBuilder.create( loc, nestedBuilder.getI32Type(), value, multiplier, shift, nestedBuilder.getBoolAttr(doubleRound)); // Move to the new zero-point. - value = nestedBuilder.create(nestedLoc, value, outputZp); + value = + nestedBuilder.create(nestedLoc, value, outputZp); // Saturate to the output size. IntegerType outIntType = @@ -1784,19 +1786,17 @@ intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); } - auto intMinVal = nestedBuilder.create( - loc, - nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMin)); - auto intMaxVal = nestedBuilder.create( - loc, - nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMax)); + auto intMinVal = nestedBuilder.create( + loc, nestedBuilder.getI32IntegerAttr(intMin)); + auto intMaxVal = nestedBuilder.create( + loc, nestedBuilder.getI32IntegerAttr(intMax)); - value = - clampHelper(nestedLoc, value, intMinVal, intMaxVal, - CmpIPredicate::slt, nestedBuilder); + value = clampHelper( + nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt, + nestedBuilder); if (outIntType.getWidth() < 32) { - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), value); @@ -1859,37 +1859,39 @@ Value x = rewriter.create(loc, 2); Value channel = rewriter.create(loc, 3); - auto hwMin = - rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto hMax = rewriter.create( + auto hwMin = rewriter.create( + loc, 
rewriter.getI32IntegerAttr(0)); + auto hMax = rewriter.create( loc, rewriter.getI32IntegerAttr(imageH - 1)); - auto wMax = rewriter.create( + auto wMax = rewriter.create( loc, rewriter.getI32IntegerAttr(imageW - 1)); - Value inY = rewriter.create(loc, rewriter.getI32Type(), y); - Value inX = rewriter.create(loc, rewriter.getI32Type(), x); + Value inY = + rewriter.create(loc, rewriter.getI32Type(), y); + Value inX = + rewriter.create(loc, rewriter.getI32Type(), x); int32_t shift = op.shift(); bool floatingPointMode = shift == 0; Value yStride, xStride, yOffset, xOffset; if (floatingPointMode) { - yStride = rewriter.create(loc, op.stride_fp()[0]); - xStride = rewriter.create(loc, op.stride_fp()[1]); - yOffset = rewriter.create(loc, op.offset_fp()[0]); - xOffset = rewriter.create(loc, op.offset_fp()[1]); + yStride = rewriter.create(loc, op.stride_fp()[0]); + xStride = rewriter.create(loc, op.stride_fp()[1]); + yOffset = rewriter.create(loc, op.offset_fp()[0]); + xOffset = rewriter.create(loc, op.offset_fp()[1]); } else { SmallVector stride, offset; getValuesFromIntArrayAttribute(op.stride(), stride); getValuesFromIntArrayAttribute(op.offset(), offset); - yStride = rewriter.create( + yStride = rewriter.create( loc, rewriter.getI32IntegerAttr(stride[0])); - xStride = rewriter.create( + xStride = rewriter.create( loc, rewriter.getI32IntegerAttr(stride[1])); - yOffset = rewriter.create( + yOffset = rewriter.create( loc, rewriter.getI32IntegerAttr(offset[0])); - xOffset = rewriter.create( + xOffset = rewriter.create( loc, rewriter.getI32IntegerAttr(offset[1])); } @@ -1899,85 +1901,89 @@ // dx = x - ix Value ix, iy, dx, dy; if (floatingPointMode) { - Value y = rewriter.create(loc, rewriter.getF32Type(), inY); - Value x = rewriter.create(loc, rewriter.getF32Type(), inX); + Value y = + rewriter.create(loc, rewriter.getF32Type(), inY); + Value x = + rewriter.create(loc, rewriter.getF32Type(), inX); - y = rewriter.create(loc, y, yStride); - x = rewriter.create(loc, x, xStride); 
+ y = rewriter.create(loc, y, yStride); + x = rewriter.create(loc, x, xStride); - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); + y = rewriter.create(loc, y, yOffset); + x = rewriter.create(loc, x, xOffset); - iy = rewriter.create(loc, y); - ix = rewriter.create(loc, x); + iy = rewriter.create(loc, y); + ix = rewriter.create(loc, x); - dy = rewriter.create(loc, y, iy); - dx = rewriter.create(loc, x, ix); + dy = rewriter.create(loc, y, iy); + dx = rewriter.create(loc, x, ix); - iy = rewriter.create(loc, rewriter.getI32Type(), iy); - ix = rewriter.create(loc, rewriter.getI32Type(), ix); + iy = rewriter.create(loc, rewriter.getI32Type(), iy); + ix = rewriter.create(loc, rewriter.getI32Type(), ix); } else { - Value shiftVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(shift)); + Value shiftVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(shift)); - Value y = rewriter.create(loc, inY, yStride); - Value x = rewriter.create(loc, inX, xStride); + Value y = rewriter.create(loc, inY, yStride); + Value x = rewriter.create(loc, inX, xStride); - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); + y = rewriter.create(loc, y, yOffset); + x = rewriter.create(loc, x, xOffset); - iy = rewriter.create(loc, y, shiftVal); - ix = rewriter.create(loc, x, shiftVal); + iy = rewriter.create(loc, y, shiftVal); + ix = rewriter.create(loc, x, shiftVal); - Value yTrunc = rewriter.create(loc, iy, shiftVal); - Value xTrunc = rewriter.create(loc, ix, shiftVal); + Value yTrunc = rewriter.create(loc, iy, shiftVal); + Value xTrunc = rewriter.create(loc, ix, shiftVal); - dy = rewriter.create(loc, y, yTrunc); - dx = rewriter.create(loc, x, xTrunc); + dy = rewriter.create(loc, y, yTrunc); + dx = rewriter.create(loc, x, xTrunc); } if (op.mode() == "NEAREST_NEIGHBOR") { Value yPred, xPred; // Round the index position towards the closest pixel location. 
if (floatingPointMode) { - auto halfVal = - rewriter.create(loc, rewriter.getF32FloatAttr(0.5f)); - yPred = rewriter.create(loc, CmpFPredicate::OGE, dy, - halfVal); - xPred = rewriter.create(loc, CmpFPredicate::OGE, dx, - halfVal); + auto halfVal = rewriter.create( + loc, rewriter.getF32FloatAttr(0.5f)); + yPred = rewriter.create(loc, arith::CmpFPredicate::OGE, + dy, halfVal); + xPred = rewriter.create(loc, arith::CmpFPredicate::OGE, + dx, halfVal); } else { - auto halfVal = rewriter.create( + auto halfVal = rewriter.create( loc, rewriter.getI32IntegerAttr(1 << (shift - 1))); - yPred = rewriter.create(loc, CmpIPredicate::sge, dy, - halfVal); - xPred = rewriter.create(loc, CmpIPredicate::sge, dx, - halfVal); + yPred = rewriter.create(loc, arith::CmpIPredicate::sge, + dy, halfVal); + xPred = rewriter.create(loc, arith::CmpIPredicate::sge, + dx, halfVal); } - auto zeroVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto oneVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); + auto zeroVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(0)); + auto oneVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); auto yOffset = rewriter.create(loc, yPred, oneVal, zeroVal); auto xOffset = rewriter.create(loc, xPred, oneVal, zeroVal); - iy = rewriter.create(loc, iy, yOffset); - ix = rewriter.create(loc, ix, xOffset); + iy = rewriter.create(loc, iy, yOffset); + ix = rewriter.create(loc, ix, xOffset); // Clamp the to be within the bounds of the input image. - iy = clampHelper(loc, iy, hwMin, hMax, CmpIPredicate::slt, - rewriter); - ix = clampHelper(loc, ix, hwMin, wMax, CmpIPredicate::slt, - rewriter); + iy = clampHelper(loc, iy, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + ix = clampHelper(loc, ix, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); // Read the value from the input array. 
- iy = rewriter.create(loc, rewriter.getIndexType(), iy); - ix = rewriter.create(loc, rewriter.getIndexType(), ix); + iy = rewriter.create(loc, rewriter.getIndexType(), + iy); + ix = rewriter.create(loc, rewriter.getIndexType(), + ix); Value result = rewriter.create( loc, input, ValueRange{batch, iy, ix, channel}); @@ -1991,25 +1997,29 @@ Value y0 = iy; Value x0 = ix; - auto oneVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); - Value y1 = rewriter.create(loc, y0, oneVal); - Value x1 = rewriter.create(loc, x0, oneVal); - - y0 = clampHelper(loc, y0, hwMin, hMax, CmpIPredicate::slt, - rewriter); - y1 = clampHelper(loc, y1, hwMin, hMax, CmpIPredicate::slt, - rewriter); - - x0 = clampHelper(loc, x0, hwMin, wMax, CmpIPredicate::slt, - rewriter); - x1 = clampHelper(loc, x1, hwMin, wMax, CmpIPredicate::slt, - rewriter); - - y0 = rewriter.create(loc, rewriter.getIndexType(), y0); - y1 = rewriter.create(loc, rewriter.getIndexType(), y1); - x0 = rewriter.create(loc, rewriter.getIndexType(), x0); - x1 = rewriter.create(loc, rewriter.getIndexType(), x1); + auto oneVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); + Value y1 = rewriter.create(loc, y0, oneVal); + Value x1 = rewriter.create(loc, x0, oneVal); + + y0 = clampHelper(loc, y0, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + y1 = clampHelper(loc, y1, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + + x0 = clampHelper(loc, x0, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); + x1 = clampHelper(loc, x1, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); + + y0 = rewriter.create(loc, rewriter.getIndexType(), + y0); + y1 = rewriter.create(loc, rewriter.getIndexType(), + y1); + x0 = rewriter.create(loc, rewriter.getIndexType(), + x0); + x1 = rewriter.create(loc, rewriter.getIndexType(), + x1); Value y0x0 = rewriter.create( loc, input, ValueRange{batch, y0, x0, channel}); @@ -2021,56 +2031,58 @@ loc, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { - auto oneVal = 
- rewriter.create(loc, rewriter.getF32FloatAttr(1.f)); + auto oneVal = rewriter.create( + loc, rewriter.getF32FloatAttr(1.f)); Value rightPart = dx; - Value leftPart = rewriter.create(loc, oneVal, dx); + Value leftPart = rewriter.create(loc, oneVal, dx); - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + y0x0 = rewriter.create(loc, y0x0, leftPart); + y0x1 = rewriter.create(loc, y0x1, rightPart); + Value topAcc = rewriter.create(loc, y0x0, y0x1); - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + y1x0 = rewriter.create(loc, y1x0, leftPart); + y1x1 = rewriter.create(loc, y1x1, rightPart); + Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; - Value topPart = rewriter.create(loc, oneVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value topPart = rewriter.create(loc, oneVal, dy); + topAcc = rewriter.create(loc, topAcc, topPart); + bottomAcc = + rewriter.create(loc, bottomAcc, bottomPart); + Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); } else { - y0x0 = rewriter.create(loc, resultElementTy, y0x0); - y0x1 = rewriter.create(loc, resultElementTy, y0x1); - y1x0 = rewriter.create(loc, resultElementTy, y1x0); - y1x1 = rewriter.create(loc, resultElementTy, y1x1); + y0x0 = rewriter.create(loc, resultElementTy, y0x0); + y0x1 = rewriter.create(loc, resultElementTy, y0x1); + y1x0 = rewriter.create(loc, resultElementTy, y1x0); + y1x1 = rewriter.create(loc, resultElementTy, y1x1); if (resultElementTy.getIntOrFloatBitWidth() > 32) { - dx = rewriter.create(loc, resultElementTy, dx); - dy = rewriter.create(loc, resultElementTy, dy); + dx = rewriter.create(loc, resultElementTy, dx); 
+ dy = rewriter.create(loc, resultElementTy, dy); } - auto unitVal = rewriter.create( + auto unitVal = rewriter.create( loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift)); Value rightPart = dx; - Value leftPart = rewriter.create(loc, unitVal, dx); + Value leftPart = rewriter.create(loc, unitVal, dx); - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + y0x0 = rewriter.create(loc, y0x0, leftPart); + y0x1 = rewriter.create(loc, y0x1, rightPart); + Value topAcc = rewriter.create(loc, y0x0, y0x1); - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + y1x0 = rewriter.create(loc, y1x0, leftPart); + y1x1 = rewriter.create(loc, y1x1, rightPart); + Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; - Value topPart = rewriter.create(loc, unitVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value topPart = rewriter.create(loc, unitVal, dy); + topAcc = rewriter.create(loc, topAcc, topPart); + bottomAcc = + rewriter.create(loc, bottomAcc, bottomPart); + Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); @@ -2125,12 +2137,12 @@ Location loc = op.getLoc(); int axis = op.axis(); Value axisValue = - rewriter.create(loc, rewriter.getIndexAttr(axis)); + rewriter.create(loc, rewriter.getIndexAttr(axis)); int rank = resultType.getRank(); SmallVector offsets, sizes, strides; sizes.reserve(rank); - strides.resize(rank, rewriter.create(loc, 1)); - offsets.resize(rank, rewriter.create(loc, 0)); + strides.resize(rank, rewriter.create(loc, 1)); + offsets.resize(rank, rewriter.create(loc, 0)); for (int i = 0; i < rank; ++i) { sizes.push_back( @@ -2140,14 +2152,14 @@ Value resultDimSize = 
sizes[axis]; for (auto arg : adaptor.getOperands().drop_front()) { auto size = rewriter.create(loc, arg, axisValue); - resultDimSize = rewriter.create(loc, resultDimSize, size); + resultDimSize = rewriter.create(loc, resultDimSize, size); } sizes[axis] = resultDimSize; Value init = rewriter.create( loc, resultType.getShape(), resultType.getElementType()); - Value zeroVal = rewriter.create( + Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(resultType.getElementType())); Value result = rewriter.create(loc, zeroVal, init).getResult(0); @@ -2156,7 +2168,8 @@ sizes[axis] = rewriter.create(loc, arg, axisValue); result = rewriter.create(loc, arg, result, offsets, sizes, strides); - offsets[axis] = rewriter.create(loc, offsets[axis], sizes[axis]); + offsets[axis] = + rewriter.create(loc, offsets[axis], sizes[axis]); } rewriter.replaceOp(op, result); return success(); @@ -2202,10 +2215,11 @@ auto index = rewriter.create(nestedLoc, i).getResult(); if (i == axis) { - auto one = rewriter.create(nestedLoc, 1); + auto one = rewriter.create(nestedLoc, 1); auto sizeMinusOne = - rewriter.create(nestedLoc, axisDimSize, one); - index = rewriter.create(nestedLoc, sizeMinusOne, index); + rewriter.create(nestedLoc, axisDimSize, one); + index = rewriter.create(nestedLoc, sizeMinusOne, + index); } indices.push_back(index); @@ -2319,9 +2333,10 @@ "tosa.pad to linalg lowering encountered an unknown element type"); } - Value lowIndex = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value lowIndex = + rewriter.create(loc, rewriter.getIndexAttr(0)); Value highIndex = - rewriter.create(loc, rewriter.getIndexAttr(1)); + rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector lowValues; SmallVector highValues; @@ -2330,22 +2345,22 @@ highValues.reserve(rank); for (int i = 0; i < rank; i++) { - Value inputIndex = rewriter.createOrFold(loc, i); + Value inputIndex = rewriter.createOrFold(loc, i); Value lowVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, 
lowIndex})); Value highVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, highIndex})); - lowVal = rewriter.createOrFold(loc, rewriter.getIndexType(), - lowVal); - highVal = rewriter.createOrFold(loc, rewriter.getIndexType(), - highVal); + lowVal = rewriter.createOrFold( + loc, rewriter.getIndexType(), lowVal); + highVal = rewriter.createOrFold( + loc, rewriter.getIndexType(), highVal); lowValues.push_back(lowVal); highValues.push_back(highVal); } - Value constant = rewriter.create(loc, constantAttr); + Value constant = rewriter.create(loc, constantAttr); auto newPadOp = linalg::PadTensorOp::createPadScalarOp( padOp.getType(), input, constant, lowValues, highValues, @@ -2400,7 +2415,7 @@ .create(loc, ArrayRef({}), resultTy.getShape(), outElementTy) .result(); - auto fillValueIdx = rewriter.create( + auto fillValueIdx = rewriter.create( loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = rewriter.create(loc, fillValueIdx, initTensorIdx) @@ -2419,7 +2434,8 @@ return rewriter.notifyMatchFailure( argmaxOp, "unsupported tosa.argmax element type"); - auto fillValueMax = rewriter.create(loc, fillValueMaxAttr); + auto fillValueMax = + rewriter.create(loc, fillValueMaxAttr); auto filledTensorMax = rewriter.create(loc, fillValueMax, initTensorMax) .result(); @@ -2449,17 +2465,17 @@ auto oldIndex = blockArgs[1]; auto oldValue = blockArgs[2]; - Value newIndex = rewriter.create( + Value newIndex = rewriter.create( nestedLoc, oldIndex.getType(), rewriter.create(loc, axis)); Value predicate; if (inElementTy.isa()) { - predicate = rewriter.create( - nestedLoc, CmpFPredicate::OGT, newValue, oldValue); + predicate = rewriter.create( + nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else if (inElementTy.isa()) { - predicate = rewriter.create( - nestedLoc, CmpIPredicate::sgt, newValue, oldValue); + predicate = rewriter.create( + nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { didEncounterError = true; return; @@ 
-2523,7 +2539,7 @@ [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; auto index0 = rewriter.create(loc, 0); - Value index1 = rewriter.create( + Value index1 = rewriter.create( loc, rewriter.getIndexType(), indexValue); auto index2 = rewriter.create(loc, 2); Value extract = rewriter.create( @@ -2584,11 +2600,11 @@ rewriter.setInsertionPointToStart(block); if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && resultElementTy.isInteger(8)) { - Value index = rewriter.create(loc, rewriter.getIndexType(), - inputValue); - Value offset = rewriter.create(loc, 128); - index = rewriter.create(loc, rewriter.getIndexType(), index, - offset); + Value index = rewriter.create( + loc, rewriter.getIndexType(), inputValue); + Value offset = rewriter.create(loc, 128); + index = rewriter.create(loc, rewriter.getIndexType(), + index, offset); Value extract = rewriter.create(loc, table, ValueRange{index}); rewriter.create(loc, extract); @@ -2597,35 +2613,35 @@ if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && resultElementTy.isInteger(32)) { - Value extend = rewriter.create( + Value extend = rewriter.create( loc, rewriter.getI32Type(), inputValue); - auto offset = - rewriter.create(loc, rewriter.getI32IntegerAttr(32768)); - auto seven = - rewriter.create(loc, rewriter.getI32IntegerAttr(7)); - auto one = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); - auto b1111111 = - rewriter.create(loc, rewriter.getI32IntegerAttr(127)); + auto offset = rewriter.create( + loc, rewriter.getI32IntegerAttr(32768)); + auto seven = rewriter.create( + loc, rewriter.getI32IntegerAttr(7)); + auto one = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); + auto b1111111 = rewriter.create( + loc, rewriter.getI32IntegerAttr(127)); // Compute the index and fractional part from the input value: // value = value + 32768 // index = value >> 7; // fraction = 0x01111111 & value - auto extendAdd = rewriter.create(loc, extend, offset); - Value index = 
- rewriter.create(loc, extendAdd, seven); - Value fraction = rewriter.create(loc, extendAdd, b1111111); + auto extendAdd = rewriter.create(loc, extend, offset); + Value index = rewriter.create(loc, extendAdd, seven); + Value fraction = + rewriter.create(loc, extendAdd, b1111111); // Extract the base and next values from the table. // base = (int32_t) table[index]; // next = (int32_t) table[index + 1]; - Value indexPlusOne = rewriter.create(loc, index, one); + Value indexPlusOne = rewriter.create(loc, index, one); - index = - rewriter.create(loc, rewriter.getIndexType(), index); - indexPlusOne = rewriter.create( + index = rewriter.create( + loc, rewriter.getIndexType(), index); + indexPlusOne = rewriter.create( loc, rewriter.getIndexType(), indexPlusOne); Value base = @@ -2633,15 +2649,18 @@ Value next = rewriter.create( loc, table, ValueRange{indexPlusOne}); - base = rewriter.create(loc, rewriter.getI32Type(), base); - next = rewriter.create(loc, rewriter.getI32Type(), next); + base = + rewriter.create(loc, rewriter.getI32Type(), base); + next = + rewriter.create(loc, rewriter.getI32Type(), next); // Use the fractional part to interpolate between the input values: // result = (base << 7) + (next - base) * fraction - Value baseScaled = rewriter.create(loc, base, seven); - Value diff = rewriter.create(loc, next, base); - Value diffScaled = rewriter.create(loc, diff, fraction); - Value result = rewriter.create(loc, baseScaled, diffScaled); + Value baseScaled = rewriter.create(loc, base, seven); + Value diff = rewriter.create(loc, next, base); + Value diffScaled = rewriter.create(loc, diff, fraction); + Value result = + rewriter.create(loc, baseScaled, diffScaled); rewriter.create(loc, result); @@ -2694,7 +2713,7 @@ pad.resize(pad.size() + 2, 0); Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = rewriter.create(loc, initialAttr); SmallVector kernel, stride; 
getValuesFromIntArrayAttribute(op.kernel(), kernel); @@ -2749,7 +2768,7 @@ Attribute initialAttr = rewriter.getZeroAttr(accETy); Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = rewriter.create(loc, initialAttr); SmallVector kernel, stride; getValuesFromIntArrayAttribute(op.kernel(), kernel); @@ -2791,18 +2810,18 @@ ArrayRef({affineMap, affineMap}), getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { - auto zero = rewriter.create(loc, 0); - auto one = rewriter.create(loc, 1); - auto iH = rewriter.create( + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto iH = rewriter.create( loc, poolingOpTy.getDimSize(1) - 1); - auto iW = rewriter.create( + auto iW = rewriter.create( loc, poolingOpTy.getDimSize(2) - 1); // Compute the indices from either end. auto y0 = rewriter.create(loc, 1); auto x0 = rewriter.create(loc, 2); - auto y1 = rewriter.create(loc, iH, y0); - auto x1 = rewriter.create(loc, iW, x0); + auto y1 = rewriter.create(loc, iH, y0); + auto x1 = rewriter.create(loc, iW, x0); // Determines what the portion of valid input is covered by the // kernel. @@ -2810,34 +2829,34 @@ if (pad == 0) return v; - auto padVal = rewriter.create(loc, pad); - Value dx = rewriter.create(loc, x, padVal); + auto padVal = rewriter.create(loc, pad); + Value dx = rewriter.create(loc, x, padVal); - Value cmp = rewriter.create(loc, CmpIPredicate::slt, - dx, zero); + Value cmp = rewriter.create( + loc, arith::CmpIPredicate::slt, dx, zero); Value offset = rewriter.create(loc, cmp, dx, zero); - return rewriter.create(loc, v, offset)->getResult(0); + return rewriter.create(loc, v, offset)->getResult(0); }; // Compute the vertical component of coverage. 
- auto kH0 = rewriter.create(loc, kernel[0]); + auto kH0 = rewriter.create(loc, kernel[0]); auto kH1 = padFn(kH0, y0, pad[2]); auto kH2 = padFn(kH1, y1, pad[3]); - auto kHCmp = - rewriter.create(loc, CmpIPredicate::slt, kH2, one); + auto kHCmp = rewriter.create( + loc, arith::CmpIPredicate::slt, kH2, one); auto kH3 = rewriter.create(loc, kHCmp, one, kH2); // compute the horizontal component of coverage. - auto kW0 = rewriter.create(loc, kernel[1]); + auto kW0 = rewriter.create(loc, kernel[1]); auto kW1 = padFn(kW0, x0, pad[4]); auto kW2 = padFn(kW1, x1, pad[5]); - auto kWCmp = - rewriter.create(loc, CmpIPredicate::slt, kW2, one); + auto kWCmp = rewriter.create( + loc, arith::CmpIPredicate::slt, kW2, one); auto kW3 = rewriter.create(loc, kWCmp, one, kW2); // Compute the total number of elements and normalize. - Value count = rewriter.create(loc, kH3, kW3); - auto countI = rewriter.create( + Value count = rewriter.create(loc, kH3, kW3); + auto countI = rewriter.create( loc, rewriter.getI32Type(), count); // Divide by the number of summed values. For floats this is just @@ -2846,20 +2865,21 @@ Value poolVal = args[0]; if (accETy.isa()) { auto countF = - rewriter.create(loc, inElementTy, countI); - poolVal = - rewriter.create(loc, poolVal, countF)->getResult(0); + rewriter.create(loc, inElementTy, countI); + poolVal = rewriter.create(loc, poolVal, countF) + ->getResult(0); } else { // If we have quantization information we need to apply an offset // for the input zp value. 
if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); - auto inputZp = rewriter.create( + auto inputZp = rewriter.create( loc, quantizationInfo.input_zp()); Value offset = - rewriter.create(loc, accETy, countI, inputZp); - poolVal = rewriter.create(loc, accETy, poolVal, offset); + rewriter.create(loc, accETy, countI, inputZp); + poolVal = + rewriter.create(loc, accETy, poolVal, offset); } // Compute the multiplier and shift values for the quantization @@ -2869,14 +2889,14 @@ int64_t numerator = ((1 << 30) + 1); int64_t shift = 30; - Value numeratorVal = rewriter.create( + Value numeratorVal = rewriter.create( loc, rewriter.getI32IntegerAttr(numerator)); Value multiplierVal = rewriter - .create(loc, rewriter.getI32Type(), + .create(loc, rewriter.getI32Type(), numeratorVal, countI) .getResult(); - Value shiftVal = rewriter.create( + Value shiftVal = rewriter.create( loc, rewriter.getI8IntegerAttr(shift)); auto scaled = @@ -2890,28 +2910,26 @@ // zeropoint. if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); - auto outputZp = rewriter.create( + auto outputZp = rewriter.create( loc, quantizationInfo.output_zp()); - scaled = - rewriter.create(loc, scaled, outputZp).getResult(); + scaled = rewriter.create(loc, scaled, outputZp) + .getResult(); } // Apply Clip. 
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth(); - auto min = rewriter.create( - loc, rewriter.getIntegerAttr( - accETy, - APInt::getSignedMinValue(outBitwidth).getSExtValue())); - auto max = rewriter.create( - loc, rewriter.getIntegerAttr( - accETy, - APInt::getSignedMaxValue(outBitwidth).getSExtValue())); - auto clamp = clampHelper( - loc, scaled, min, max, CmpIPredicate::slt, rewriter); + auto min = rewriter.create( + loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(), + accETy); + auto max = rewriter.create( + loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(), + accETy); + auto clamp = clampHelper( + loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter); // Convert type. - poolVal = rewriter.create(loc, resultETy, clamp); + poolVal = rewriter.create(loc, resultETy, clamp); } // Cast to output type. diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -12,6 +12,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" @@ -33,9 +34,9 @@ : public TosaToLinalgOnTensorsBase { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnFunction() override { diff --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRStandard MLIRPass diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp 
b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToStandard/TosaToStandard.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -28,7 +29,7 @@ LogicalResult matchAndRewrite(tosa::ConstOp op, PatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp<::ConstantOp>(op, op.value()); + rewriter.replaceOpWithNewOp(op, op.value()); return success(); } }; @@ -67,12 +68,12 @@ bool doubleRound = op.double_round(); Type inType = op.value().getType(); - Value one8 = rewriter.create( + Value one8 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1)); - Value one64 = rewriter.create( + Value one64 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); - Value shiftSubOne8 = rewriter.create(loc, shift8, one8); + Value shiftSubOne8 = rewriter.create(loc, shift8, one8); // The rounding value semantics below equate to the following code: // int64_t round = 1 << (shift - 1); @@ -83,45 +84,45 @@ // // Note that minimal bitwidth operators are used throughout the block. 
- Value round64 = rewriter.create( + Value round64 = rewriter.create( loc, one64, - rewriter.create(loc, rewriter.getI64Type(), - shiftSubOne8)); + rewriter.create(loc, rewriter.getI64Type(), + shiftSubOne8)); // Double rounding is performing a round operation before the shift if (doubleRound) { - Value one32 = rewriter.create( + Value one32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); - Value shift32 = rewriter.create( - loc, rewriter.getI32Type(), shift8); - Value thirty32 = rewriter.create( + Value shift32 = + rewriter.create(loc, rewriter.getI32Type(), shift8); + Value thirty32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30)); Value shiftThirty32 = - rewriter.create(loc, one32, thirty32); - Value shiftThirty64 = rewriter.create( + rewriter.create(loc, one32, thirty32); + Value shiftThirty64 = rewriter.create( loc, rewriter.getI64Type(), shiftThirty32); // Round value needs to with be added or subtracted depending on the sign // of the input value. Value roundAdd64 = - rewriter.create(loc, round64, shiftThirty64); + rewriter.create(loc, round64, shiftThirty64); Value roundSub64 = - rewriter.create(loc, round64, shiftThirty64); + rewriter.create(loc, round64, shiftThirty64); Value zero32 = - rewriter.create(loc, rewriter.getZeroAttr(inType)); - Value valueGreaterThanZero = rewriter.create( - loc, CmpIPredicate::sge, value32, zero32); + rewriter.create(loc, rewriter.getZeroAttr(inType)); + Value valueGreaterThanZero = rewriter.create( + loc, arith::CmpIPredicate::sge, value32, zero32); Value doubleRound64 = rewriter.create( loc, valueGreaterThanZero, roundAdd64, roundSub64); // We only perform double rounding if the shift value is greater than 32. 
- Value thirtyTwo32 = rewriter.create( + Value thirtyTwo32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32)); - Value shiftGreaterThanThirtyTwo = rewriter.create( - loc, CmpIPredicate::sge, shift32, thirtyTwo32); + Value shiftGreaterThanThirtyTwo = rewriter.create( + loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); round64 = rewriter.create(loc, shiftGreaterThanThirtyTwo, doubleRound64, round64); } @@ -133,20 +134,19 @@ // Note that multiply and shift need to be perform in i64 to preserve bits. Value value64 = - rewriter.create(loc, rewriter.getI64Type(), value32); - Value multiplier64 = rewriter.create( + rewriter.create(loc, rewriter.getI64Type(), value32); + Value multiplier64 = rewriter.create( loc, rewriter.getI64Type(), multiplier32); Value shift64 = - rewriter.create(loc, rewriter.getI64Type(), shift8); + rewriter.create(loc, rewriter.getI64Type(), shift8); // Multiply as a pair of i64 values to guarantee the end value fits. - Value result64 = rewriter.create(loc, value64, multiplier64); - result64 = rewriter.create(loc, result64, round64); - result64 = - rewriter.create(loc, result64, shift64); + Value result64 = rewriter.create(loc, value64, multiplier64); + result64 = rewriter.create(loc, result64, round64); + result64 = rewriter.create(loc, result64, shift64); - Value result32 = rewriter.create( - loc, rewriter.getI32Type(), result64); + Value result32 = + rewriter.create(loc, rewriter.getI32Type(), result64); rewriter.replaceOp(op, result32); return success(); diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp @@ -12,6 +12,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToStandard/TosaToStandard.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include 
"mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -34,6 +35,7 @@ target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -8,6 +8,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRGPUOps MLIRLLVMIR MLIRMemRef diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -16,6 +16,7 @@ #include "../PassDetail.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -116,7 +117,7 @@ /// Return true if the constant is a splat to a 2D vector so that it can be /// converted to a MMA constant matrix op. -static bool constantSupportsMMAMatrixType(ConstantOp constantOp) { +static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { auto vecType = constantOp.getType().dyn_cast(); if (!vecType || vecType.getRank() != 2) return false; @@ -138,7 +139,7 @@ return transferWriteSupportsMMAMatrixType(transferWrite); if (auto contract = dyn_cast(op)) return contractSupportsMMAMatrixType(contract); - if (auto constant = dyn_cast(op)) + if (auto constant = dyn_cast(op)) return constantSupportsMMAMatrixType(constant); if (auto broadcast = dyn_cast(op)) return broadcastSupportsMMAMatrixType(broadcast); @@ -324,13 +325,13 @@ } /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 
-static void convertConstantOp(ConstantOp op, +static void convertConstantOp(arith::ConstantOp op, llvm::DenseMap &valueMapping) { assert(constantSupportsMMAMatrixType(op)); OpBuilder b(op); - Attribute splat = op.getValue().cast().getSplatValue(); + Attribute splat = op.value().cast().getSplatValue(); auto scalarConstant = - b.create(op.getLoc(), splat.getType(), splat); + b.create(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); auto vecType = op.getType().cast(); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( @@ -439,7 +440,7 @@ convertTransferWriteOp(transferWrite, valueMapping); } else if (auto contractOp = dyn_cast(op)) { convertContractOp(contractOp, valueMapping); - } else if (auto constantOp = dyn_cast(op)) { + } else if (auto constantOp = dyn_cast(op)) { convertConstantOp(constantOp, valueMapping); } else if (auto broadcastOp = dyn_cast(op)) { convertBroadcastOp(broadcastOp, valueMapping); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -13,6 +13,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRArmNeon MLIRArmSVE MLIRArmSVETransforms diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -59,7 +60,7 @@ return rewriter.create(loc, from, into, offset); return rewriter.create( loc, vectorType, from, into, - rewriter.create(loc, offset)); 
+ rewriter.create(loc, offset)); } // Helper that picks the proper sequence for extracting. @@ -86,7 +87,7 @@ return rewriter.create(loc, vector, offset); return rewriter.create( loc, vectorType.getElementType(), vector, - rewriter.create(loc, offset)); + rewriter.create(loc, offset)); } // Helper that returns a subset of `arrayAttr` as a vector of int64_t. @@ -795,8 +796,8 @@ auto loc = op.getLoc(); auto elemType = vType.getElementType(); - Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); + Value zero = rewriter.create( + loc, elemType, rewriter.getZeroAttr(elemType)); Value desc = rewriter.create(loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { Value extrLHS = rewriter.create(loc, op.lhs(), i); @@ -1144,11 +1145,11 @@ if (rank == 0) { switch (conversion) { case PrintConversion::ZeroExt64: - value = rewriter.create( + value = rewriter.create( loc, value, IntegerType::get(rewriter.getContext(), 64)); break; case PrintConversion::SignExt64: - value = rewriter.create( + value = rewriter.create( loc, value, IntegerType::get(rewriter.getContext(), 64)); break; case PrintConversion::None: @@ -1231,8 +1232,8 @@ } // Extract/insert on a lower ranked extract strided slice op. 
- Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); + Value zero = rewriter.create( + loc, elemType, rewriter.getZeroAttr(elemType)); Value res = rewriter.create(loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/AMX/Transforms.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" @@ -42,6 +43,7 @@ // Override explicitly to allow conditional dialect dependence. void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + registry.insert(); registry.insert(); if (enableArmNeon) registry.insert(); @@ -81,6 +83,7 @@ // Architecture specific augmentations. 
LLVMConversionTarget target(getContext()); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); diff --git a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt @@ -8,6 +8,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRLLVMIR MLIRMemRef MLIRTransforms diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -17,6 +17,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -123,8 +124,8 @@ return Value(); Location loc = xferOp.getLoc(); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); return b.create(loc, xferOp.mask(), ivI32); } @@ -171,13 +172,14 @@ bindDims(xferOp.getContext(), d0, d1); Value base = xferOp.indices()[dim.getValue()]; Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); - cond = lb.create(CmpIPredicate::sgt, memrefDim, memrefIdx); + cond = lb.create(arith::CmpIPredicate::sgt, memrefDim, + memrefIdx); } // Condition check 2: Masked in? if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { if (cond) - cond = lb.create(cond, maskCond); + cond = lb.create(cond, maskCond); else cond = maskCond; } @@ -704,10 +706,10 @@ } // Loop bounds and step. 
- auto lb = locB.create(0); - auto ub = locB.create( + auto lb = locB.create(0); + auto ub = locB.create( castedDataType.getDimSize(castedDataType.getRank() - 1)); - auto step = locB.create(1); + auto step = locB.create(1); // TransferWriteOps that operate on tensors return the modified tensor and // require a loop state. auto loopState = Strategy::initialLoopState(xferOp); @@ -897,7 +899,7 @@ // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = rewriter.create(loc, i); vec = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), @@ -1023,7 +1025,7 @@ // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = rewriter.create(loc, i); auto updatedSource = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), @@ -1114,8 +1116,8 @@ ValueRange loopState) { SmallVector indices; auto dim = get1dMemrefIndices(b, xferOp, iv, indices); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); auto vec = loopState[0]; // In case of out-of-bounds access, leave `vec` as is (was initialized with @@ -1147,8 +1149,8 @@ ValueRange /*loopState*/) { SmallVector indices; auto dim = get1dMemrefIndices(b, xferOp, iv, indices); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); // Nothing to do in case of out-of-bounds access. generateInBoundsCheck( @@ -1224,9 +1226,10 @@ // Loop bounds, step, state... 
Location loc = xferOp.getLoc(); auto vecType = xferOp.getVectorType(); - auto lb = rewriter.create(loc, 0); - auto ub = rewriter.create(loc, vecType.getDimSize(0)); - auto step = rewriter.create(loc, 1); + auto lb = rewriter.create(loc, 0); + auto ub = + rewriter.create(loc, vecType.getDimSize(0)); + auto step = rewriter.create(loc, 1); auto loopState = Strategy1d::initialLoopState(rewriter, xferOp); // Generate for loop. diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -221,7 +222,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + return builder.create(loc, type, value); } /// A utility function to check if a value is defined at the top level of an @@ -1884,12 +1885,11 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { - auto lbConst = lb.getDefiningOp(); - auto ubConst = ub.getDefiningOp(); + auto lbConst = lb.getDefiningOp(); + auto ubConst = ub.getDefiningOp(); if (lbConst && ubConst) - return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(), - ubConst.getValue(), step, - bodyBuilderFn); + return buildAffineLoopFromConstants(builder, loc, lbConst.value(), + ubConst.value(), step, bodyBuilderFn); return builder.create(loc, lb, builder.getDimIdentityMap(), ub, builder.getDimIdentityMap(), step, /*iterArgs=*/llvm::None, bodyBuilderFn); diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt 
b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt @@ -11,6 +11,7 @@ MLIRAffineOpsIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRLoopLikeInterface MLIRMemRef diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -23,6 +23,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -199,7 +200,7 @@ void AffineDataCopyGeneration::runOnFunction() { FuncOp f = getFunction(); OpBuilder topBuilder(f.getBody()); - zeroIndex = topBuilder.create(f.getLoc(), 0); + zeroIndex = topBuilder.create(f.getLoc(), 0); // Nests that are copy-in's or copy-out's; the root AffineForOps of those // nests are stored herein. diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -18,6 +18,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -81,7 +82,7 @@ } else if (isa(op)) { // TODO: Support DMA ops. return false; - } else if (!isa(op)) { + } else if (!isa(op)) { // Register op in the set of ops that have users. 
opsWithUsers.insert(&op); if (isa(op)) { diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -21,6 +21,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRAffineUtils + MLIRArithmetic MLIRIR MLIRMemRef MLIRPass diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -343,8 +344,8 @@ /// %A = alloc (%M, %N) : memref /// %B = alloc (%M, %N) : memref /// %C = alloc (%M, %N) : memref -/// %f1 = constant 1.0 : f32 -/// %f2 = constant 2.0 : f32 +/// %f1 = arith.constant 1.0 : f32 +/// %f2 = arith.constant 2.0 : f32 /// affine.for %i0 = 0 to %M { /// affine.for %i1 = 0 to %N { /// // non-scoped %f1 @@ -361,18 +362,18 @@ /// affine.for %i5 = 0 to %N { /// %a5 = affine.load %A[%i4, %i5] : memref /// %b5 = affine.load %B[%i4, %i5] : memref -/// %s5 = addf %a5, %b5 : f32 +/// %s5 = arith.addf %a5, %b5 : f32 /// // non-scoped %f1 -/// %s6 = addf %s5, %f1 : f32 +/// %s6 = arith.addf %s5, %f1 : f32 /// // non-scoped %f2 -/// %s7 = addf %s5, %f2 : f32 +/// %s7 = arith.addf %s5, %f2 : f32 /// // diamond dependency. 
-/// %s8 = addf %s7, %s6 : f32 +/// %s8 = arith.addf %s7, %s6 : f32 /// affine.store %s8, %C[%i4, %i5] : memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.constant 7 : index +/// %c42 = arith.constant 42 : index /// %res = load %C[%c7, %c42] : memref /// return %res : f32 /// } @@ -389,11 +390,11 @@ /// %0 = alloc(%arg0, %arg1) : memref /// %1 = alloc(%arg0, %arg1) : memref /// %2 = alloc(%arg0, %arg1) : memref -/// %cst = constant 1.0 : f32 -/// %cst_0 = constant 2.0 : f32 +/// %cst = arith.constant 1.0 : f32 +/// %cst_0 = arith.constant 2.0 : f32 /// affine.for %i0 = 0 to %arg0 { /// affine.for %i1 = 0 to %arg1 step 256 { -/// %cst_1 = constant dense, 1.0> : +/// %cst_1 = arith.constant dense, 1.0> : /// vector<256xf32> /// vector.transfer_write %cst_1, %0[%i0, %i1] : /// vector<256xf32>, memref @@ -401,7 +402,7 @@ /// } /// affine.for %i2 = 0 to %arg0 { /// affine.for %i3 = 0 to %arg1 step 256 { -/// %cst_2 = constant dense, 2.0> : +/// %cst_2 = arith.constant dense, 2.0> : /// vector<256xf32> /// vector.transfer_write %cst_2, %1[%i2, %i3] : /// vector<256xf32>, memref @@ -413,20 +414,20 @@ /// memref, vector<256xf32> /// %4 = vector.transfer_read %1[%i4, %i5] : /// memref, vector<256xf32> -/// %5 = addf %3, %4 : vector<256xf32> -/// %cst_3 = constant dense, 1.0> : +/// %5 = arith.addf %3, %4 : vector<256xf32> +/// %cst_3 = arith.constant dense, 1.0> : /// vector<256xf32> -/// %6 = addf %5, %cst_3 : vector<256xf32> -/// %cst_4 = constant dense, 2.0> : +/// %6 = arith.addf %5, %cst_3 : vector<256xf32> +/// %cst_4 = arith.constant dense, 2.0> : /// vector<256xf32> -/// %7 = addf %5, %cst_4 : vector<256xf32> -/// %8 = addf %7, %6 : vector<256xf32> +/// %7 = arith.addf %5, %cst_4 : vector<256xf32> +/// %8 = arith.addf %7, %6 : vector<256xf32> /// vector.transfer_write %8, %2[%i4, %i5] : /// vector<256xf32>, memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.constant 7 : 
index +/// %c42 = arith.constant 42 : index /// %9 = load %2[%c7, %c42] : memref /// return %9 : f32 /// } @@ -443,11 +444,11 @@ /// %0 = alloc(%arg0, %arg1) : memref /// %1 = alloc(%arg0, %arg1) : memref /// %2 = alloc(%arg0, %arg1) : memref -/// %cst = constant 1.0 : f32 -/// %cst_0 = constant 2.0 : f32 +/// %cst = arith.constant 1.0 : f32 +/// %cst_0 = arith.constant 2.0 : f32 /// affine.for %i0 = 0 to %arg0 step 32 { /// affine.for %i1 = 0 to %arg1 step 256 { -/// %cst_1 = constant dense, 1.0> : +/// %cst_1 = arith.constant dense, 1.0> : /// vector<32x256xf32> /// vector.transfer_write %cst_1, %0[%i0, %i1] : /// vector<32x256xf32>, memref @@ -455,7 +456,7 @@ /// } /// affine.for %i2 = 0 to %arg0 step 32 { /// affine.for %i3 = 0 to %arg1 step 256 { -/// %cst_2 = constant dense, 2.0> : +/// %cst_2 = arith.constant dense, 2.0> : /// vector<32x256xf32> /// vector.transfer_write %cst_2, %1[%i2, %i3] : /// vector<32x256xf32>, memref @@ -467,20 +468,20 @@ /// memref vector<32x256xf32> /// %4 = vector.transfer_read %1[%i4, %i5] : /// memref, vector<32x256xf32> -/// %5 = addf %3, %4 : vector<32x256xf32> -/// %cst_3 = constant dense, 1.0> : +/// %5 = arith.addf %3, %4 : vector<32x256xf32> +/// %cst_3 = arith.constant dense, 1.0> : /// vector<32x256xf32> -/// %6 = addf %5, %cst_3 : vector<32x256xf32> -/// %cst_4 = constant dense, 2.0> : +/// %6 = arith.addf %5, %cst_3 : vector<32x256xf32> +/// %cst_4 = arith.constant dense, 2.0> : /// vector<32x256xf32> -/// %7 = addf %5, %cst_4 : vector<32x256xf32> -/// %8 = addf %7, %6 : vector<32x256xf32> +/// %7 = arith.addf %5, %cst_4 : vector<32x256xf32> +/// %8 = arith.addf %7, %6 : vector<32x256xf32> /// vector.transfer_write %8, %2[%i4, %i5] : /// vector<32x256xf32>, memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.constant 7 : index +/// %c42 = arith.constant 42 : index /// %9 = load %2[%c7, %c42] : memref /// return %9 : f32 /// } @@ -510,11 +511,11 @@ /// Consider the following 
example: /// ```mlir /// func @vecred(%in: memref<512xf32>) -> f32 { -/// %cst = constant 0.000000e+00 : f32 +/// %cst = arith.constant 0.000000e+00 : f32 /// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) { /// %ld = affine.load %in[%i] : memref<512xf32> /// %cos = math.cos %ld : f32 -/// %add = addf %part_sum, %cos : f32 +/// %add = arith.addf %part_sum, %cos : f32 /// affine.yield %add : f32 /// } /// return %sum : f32 @@ -530,18 +531,18 @@ /// ```mlir /// #map = affine_map<(d0) -> (-d0 + 500)> /// func @vecred(%arg0: memref<512xf32>) -> f32 { -/// %cst = constant 0.000000e+00 : f32 -/// %cst_0 = constant dense<0.000000e+00> : vector<128xf32> +/// %cst = arith.constant 0.000000e+00 : f32 +/// %cst_0 = arith.constant dense<0.000000e+00> : vector<128xf32> /// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0) /// -> (vector<128xf32>) { /// // %2 is the number of iterations left in the original loop. /// %2 = affine.apply #map(%arg1) /// %3 = vector.create_mask %2 : vector<128xi1> -/// %cst_1 = constant 0.000000e+00 : f32 +/// %cst_1 = arith.constant 0.000000e+00 : f32 /// %4 = vector.transfer_read %arg0[%arg1], %cst_1 : /// memref<512xf32>, vector<128xf32> /// %5 = math.cos %4 : vector<128xf32> -/// %6 = addf %arg2, %5 : vector<128xf32> +/// %6 = arith.addf %arg2, %5 : vector<128xf32> /// // We filter out the effect of last 12 elements using the mask. /// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32> /// affine.yield %7 : vector<128xf32> @@ -673,8 +674,8 @@ /// the vectorized operations. /// /// Example: - /// * 'replaced': %0 = addf %1, %2 : f32 - /// * 'replacement': %0 = addf %1, %2 : vector<128xf32> + /// * 'replaced': %0 = arith.addf %1, %2 : f32 + /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> void registerOpVectorReplacement(Operation *replaced, Operation *replacement); /// Registers the vector replacement of a scalar value. The replacement @@ -771,8 +772,8 @@ /// the vectorized operations. 
/// /// Example: -/// * 'replaced': %0 = addf %1, %2 : f32 -/// * 'replacement': %0 = addf %1, %2 : vector<128xf32> +/// * 'replaced': %0 = arith.addf %1, %2 : f32 +/// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> void VectorizationState::registerOpVectorReplacement(Operation *replaced, Operation *replacement) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n"); @@ -940,14 +941,14 @@ /// Tries to transform a scalar constant into a vector constant. Returns the /// vector constant if the scalar type is valid vector element type. Returns /// nullptr, otherwise. -static ConstantOp vectorizeConstant(ConstantOp constOp, - VectorizationState &state) { +static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp, + VectorizationState &state) { Type scalarTy = constOp.getType(); if (!VectorType::isValidElementType(scalarTy)) return nullptr; auto vecTy = getVectorType(scalarTy, state.strategy); - auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue()); + auto vecAttr = DenseElementsAttr::get(vecTy, constOp.value()); OpBuilder::InsertionGuard guard(state.builder); Operation *parentOp = state.builder.getInsertionBlock()->getParentOp(); @@ -958,7 +959,8 @@ isa(parentOp) && "Expected a vectorized for op"); auto vecForOp = cast(parentOp); state.builder.setInsertionPointToStart(vecForOp.getBody()); - auto newConstOp = state.builder.create(constOp.getLoc(), vecAttr); + auto newConstOp = + state.builder.create(constOp.getLoc(), vecAttr); // Register vector replacement for future uses in the scope. state.registerOpVectorReplacement(constOp, newConstOp); @@ -968,9 +970,9 @@ /// Creates a constant vector filled with the neutral elements of the given /// reduction. The scalar type of vector elements will be taken from /// `oldOperand`. 
-static ConstantOp createInitialVector(AtomicRMWKind reductionKind, - Value oldOperand, - VectorizationState &state) { +static arith::ConstantOp createInitialVector(AtomicRMWKind reductionKind, + Value oldOperand, + VectorizationState &state) { Type scalarTy = oldOperand.getType(); if (!VectorType::isValidElementType(scalarTy)) return nullptr; @@ -980,7 +982,7 @@ auto vecTy = getVectorType(scalarTy, state.strategy); auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); auto newConstOp = - state.builder.create(oldOperand.getLoc(), vecAttr); + state.builder.create(oldOperand.getLoc(), vecAttr); return newConstOp; } @@ -1120,8 +1122,8 @@ "Vector op not found in replacement map"); // Vectorize constant. - if (auto constOp = operand.getDefiningOp()) { - ConstantOp vecConstant = vectorizeConstant(constOp, state); + if (auto constOp = operand.getDefiningOp()) { + auto vecConstant = vectorizeConstant(constOp, state); LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant); return vecConstant.getResult(); } @@ -1242,7 +1244,7 @@ return false; Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy, state.builder, value.getLoc()); - if (auto constOp = dyn_cast_or_null(value.getDefiningOp())) + if (auto constOp = dyn_cast_or_null(value.getDefiningOp())) return constOp.value() == valueAttr; return false; } @@ -1417,7 +1419,7 @@ // being added to the accumulator by inserting `select` operations, for // example: // - // %res = addf %acc, %val : vector<128xf32> + // %res = arith.addf %acc, %val : vector<128xf32> // %res_masked = select %mask, %res, %acc : vector<128xi1>, vector<128xf32> // affine.yield %res_masked : vector<128xf32> // @@ -1464,7 +1466,7 @@ return vectorizeAffineForOp(forOp, state); if (auto yieldOp = dyn_cast(op)) return vectorizeAffineYieldOp(yieldOp, state); - if (auto constant = dyn_cast(op)) + if (auto constant = dyn_cast(op)) return vectorizeConstant(constant, state); // Other ops with regions are not supported. 
diff --git a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/IR/Builders.h" #include "mlir/Transforms/InliningUtils.h" using namespace mlir; @@ -28,10 +29,18 @@ }; } // end anonymous namespace -void mlir::arith::ArithmeticDialect::initialize() { +void arith::ArithmeticDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" >(); addInterfaces(); } + +/// Materialize an integer or floating point constant. +Operation *arith::ArithmeticDialect::materializeConstant(OpBuilder &builder, + Attribute value, + Type type, + Location loc) { + return builder.create(loc, value, type); +} diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -75,6 +75,92 @@ #include "ArithmeticCanonicalization.inc" } // end anonymous namespace +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +void arith::ConstantOp::getAsmResultNames( + function_ref setNameFn) { + auto type = getType(); + if (auto intCst = value().dyn_cast()) { + auto intType = type.dyn_cast(); + + // Sugar i1 constants with 'true' and 'false'. 
+ if (intType && intType.getWidth() == 1) + return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); + + // Otherwise, build a compex name with the value and type. + SmallString<32> specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << 'c' << intCst.getInt(); + if (intType) + specialName << '_' << type; + setNameFn(getResult(), specialName.str()); + } else { + setNameFn(getResult(), "cst"); + } +} + +bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { + // The value's type must be the same as the provided type. + if (value.getType() != type) + return false; + // Integers values must be signless. + if (type.isa() && !type.cast().isSignless()) + return false; + // Integer, float, and element attributes are buildable. + return value.isa(); +} + +OpFoldResult arith::ConstantOp::fold(ArrayRef operands) { + return value(); +} + +void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, + int64_t value, unsigned width) { + auto type = builder.getIntegerType(width); + arith::ConstantOp::build(builder, result, type, + builder.getIntegerAttr(type, value)); +} + +void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, + int64_t value, Type type) { + assert(type.isSignlessInteger() && + "ConstantIntOp can only have signless integer type values"); + arith::ConstantOp::build(builder, result, type, + builder.getIntegerAttr(type, value)); +} + +bool arith::ConstantIntOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isSignlessInteger(); + return false; +} + +void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, + const APFloat &value, FloatType type) { + arith::ConstantOp::build(builder, result, type, + builder.getFloatAttr(type, value)); +} + +bool arith::ConstantFloatOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isa(); + return false; +} + 
+void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, + int64_t value) { + arith::ConstantOp::build(builder, result, builder.getIndexType(), + builder.getIndexAttr(value)); +} + +bool arith::ConstantIndexOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isIndex(); + return false; +} + //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// @@ -439,6 +525,49 @@ operands, [](APFloat a, APFloat b) { return a / b; }); } +//===----------------------------------------------------------------------===// +// Utility functions for verifying cast ops +//===----------------------------------------------------------------------===// + +template +using type_list = std::tuple *; + +/// Returns a non-null type only if the provided type is one of the allowed +/// types or one of the allowed shaped types of the allowed types. Returns the +/// element type if a valid shaped type is provided. +template +static Type getUnderlyingType(Type type, type_list, + type_list) { + if (type.isa() && !type.isa()) + return {}; + + auto underlyingType = getElementTypeOrSelf(type); + if (!underlyingType.isa()) + return {}; + + return underlyingType; +} + +/// Get allowed underlying types for vectors and tensors. +template +static Type getTypeIfLike(Type type) { + return getUnderlyingType(type, type_list(), + type_list()); +} + +/// Get allowed underlying types for vectors, tensors, and memrefs. 
+template +static Type getTypeIfLikeOrMemRef(Type type) { + return getUnderlyingType(type, + type_list(), + type_list()); +} + +static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { + return inputs.size() == 1 && outputs.size() == 1 && + succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); +} + //===----------------------------------------------------------------------===// // Verifiers for integer and floating point extension/truncation ops //===----------------------------------------------------------------------===// @@ -469,6 +598,21 @@ return success(); } +/// Validate a cast that changes the width of a type. +template