diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -442,7 +442,7 @@ attributes: * `OPERATION_NAME` attribute with the `str` fully qualified operation name - (i.e. `std.absf`). + (i.e. `math.abs`). * An `__init__` method for the *default builder* if one is defined or inferred for the operation. * `@property` getter for each operand or result (using an auto-generated name diff --git a/mlir/docs/BufferDeallocationInternals.md b/mlir/docs/BufferDeallocationInternals.md --- a/mlir/docs/BufferDeallocationInternals.md +++ b/mlir/docs/BufferDeallocationInternals.md @@ -405,7 +405,7 @@ ```mlir func @nested_region_control_flow(%arg0 : index, %arg1 : index) -> memref { - %0 = cmpi "eq", %arg0, %arg1 : index + %0 = arith.cmpi "eq", %arg0, %arg1 : index %1 = memref.alloc(%arg0, %arg0) : memref %2 = scf.if %0 -> (memref) { scf.yield %1 : memref // %2 will be an alias of %1 @@ -426,7 +426,7 @@ ```mlir func @nested_region_control_flow(%arg0: index, %arg1: index) -> memref { - %0 = cmpi "eq", %arg0, %arg1 : index + %0 = arith.cmpi "eq", %arg0, %arg1 : index %1 = memref.alloc(%arg0, %arg0) : memref %2 = scf.if %0 -> (memref) { scf.yield %1 : memref @@ -518,7 +518,7 @@ %res: memref<2xf32>) { %0 = scf.for %i = %lb to %ub step %step iter_args(%iterBuf = %buf) -> memref<2xf32> { - %1 = cmpi "eq", %i, %ub : index + %1 = arith.cmpi "eq", %i, %ub : index %2 = scf.if %1 -> (memref<2xf32>) { %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias due to a // divergent allocation @@ -557,7 +557,7 @@ %4 = memref.clone %buf : (memref<2xf32>) -> (memref<2xf32>) %0 = scf.for %i = %lb to %ub step %step iter_args(%iterBuf = %4) -> memref<2xf32> { - %1 = cmpi "eq", %i, %ub : index + %1 = arith.cmpi "eq", %i, %ub : index %2 = scf.if %1 -> (memref<2xf32>) { %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias use(%3) diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md 
--- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -199,46 +199,50 @@ ### Other partial bufferization examples -- `linalg-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Linalg/bufferize.mlir#L1)) - - - Bufferizes the `linalg` dialect. - - This is an example of how to simultaneously bufferize all the ops that - satisfy a certain OpInterface with a single pattern. Specifically, - `BufferizeAnyLinalgOp` - ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L170)) - bufferizes any ops that implements the `LinalgOp` interface. - -- `scf-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/SCF/bufferize.mlir#L1)) - - - Bufferizes ops from the `scf` dialect. - - This is an example of how to bufferize ops that implement - `RegionBranchOpInterface` (that is, they use regions to represent control - flow). - - The bulk of the work is done by - `lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp` - ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp#L1)), - which is well-commented and covers how to correctly convert ops that contain - regions. 
- -- `func-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/func-bufferize.mlir#L1)) - - - Bufferizes `func`, `call`, and `BranchOpInterface` ops. - - This is an example of how to bufferize ops that have multi-block regions. - - This is an example of a pass that is not split along dialect subdivisions. - -- `tensor-constant-bufferize` - ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp#L1), - [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir#L1)) - - Bufferizes only `std.constant` ops of `tensor` type. - - This is an example of setting up the legality so that only a subset of - `std.constant` ops get bufferized. - - This is an example of a pass that is not split along dialect subdivisions. +- `linalg-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Linalg/bufferize.mlir#L1)) + + - Bufferizes the `linalg` dialect. + - This is an example of how to simultaneously bufferize all the ops that + satisfy a certain OpInterface with a single pattern. Specifically, + `BufferizeAnyLinalgOp` + ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp#L170)) + bufferizes any ops that implements the `LinalgOp` interface. 
+ +- `scf-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/SCF/bufferize.mlir#L1)) + + - Bufferizes ops from the `scf` dialect. + - This is an example of how to bufferize ops that implement + `RegionBranchOpInterface` (that is, they use regions to represent + control flow). + - The bulk of the work is done by + `lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp` + ([code](https://github.com/llvm/llvm-project/blob/daaaed6bb89044ac58a23f1bb1ccdd12342a5a58/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp#L1)), + which is well-commented and covers how to correctly convert ops that + contain regions. + +- `func-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/func-bufferize.mlir#L1)) + + - Bufferizes `func`, `call`, and `BranchOpInterface` ops. + - This is an example of how to bufferize ops that have multi-block + regions. + - This is an example of a pass that is not split along dialect + subdivisions. + +- `tensor-constant-bufferize` + ([code](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp#L1), + [test](https://github.com/llvm/llvm-project/blob/bc8acf2ce8ad6e8c9b1d97b2e02d3f4ad26e1d9d/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir#L1)) + + - Bufferizes only `arith.const` ops of `tensor` type. + - This is an example of setting up the legality so that only a subset of + `std.constant` ops get bufferized. + - This is an example of a pass that is not split along dialect + subdivisions. 
## How to write a finalizing bufferization pass diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -486,7 +486,7 @@ want to allocate memory and store some computation (in pseudocode): ```mlir -%dst = addi %lhs, %rhs +%dst = arith.addi %lhs, %rhs ``` into @@ -494,7 +494,7 @@ ```mlir %shape = shape %lhs %mem = alloc %shape -%sum = addi %lhs, %rhs +%sum = arith.addi %lhs, %rhs store %mem, %sum %dst = load %mem ``` diff --git a/mlir/docs/Diagnostics.md b/mlir/docs/Diagnostics.md --- a/mlir/docs/Diagnostics.md +++ b/mlir/docs/Diagnostics.md @@ -301,7 +301,7 @@ // Expect an error on an adjacent line. func @foo(%a : f32) { // expected-error@+1 {{unknown comparison predicate "foo"}} - %result = cmpf "foo", %a, %a : f32 + %result = arith.cmpf "foo", %a, %a : f32 return } diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md --- a/mlir/docs/Dialects/Linalg/_index.md +++ b/mlir/docs/Dialects/Linalg/_index.md @@ -314,7 +314,7 @@ ins(%A, %B: memref, memref) outs(%C: memref) { ^bb0(%a: f32, %b: f32, %c: f32): - %d = addf %a, %b : f32 + %d = arith.addf %a, %b : f32 linalg.yield %d : f32 } @@ -338,7 +338,7 @@ scf.for %arg4 = %c0 to %1 step %c1 { %2 = load %arg0[%arg3, %arg4] : memref %3 = load %arg1[%arg3, %arg4] : memref - %4 = addf %2, %3 : f32 + %4 = arith.addf %2, %3 : f32 store %4, %arg2[%arg3, %arg4] : memref } } @@ -387,7 +387,7 @@ ins(%A, %B: memref, memref) outs(%C: memref) { ^bb0(%a: f32, %b: f32, %c: f32): - %d = addf %a, %b : f32 + %d = arith.addf %a, %b : f32 linalg.yield %d : f32 } return diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md --- a/mlir/docs/Dialects/Vector.md +++ b/mlir/docs/Dialects/Vector.md @@ -95,8 +95,8 @@ ### Virtual Vector Ops Some existing Standard and Vector Dialect on `n-D` `vector` types comprise: ``` -%2 = std.addf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> -%2 = 
std.mulf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> +%2 = arith.addf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> +%2 = arith.mulf %0, %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> %2 = std.splat %1 : vector<3x7x8xf32> // -> vector<3x7x8xf32> %1 = vector.extract %0[1]: vector<3x7x8xf32> // -> vector<7x8xf32> diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -23,13 +23,15 @@ Besides operations part of the EmitC dialect, the Cpp targets supports translating the following operations: -* 'std' Dialect - * `std.br` - * `std.call` - * `std.cond_br` - * `std.constant` - * `std.return` -* 'scf' Dialect - * `scf.for` - * `scf.if` - * `scf.yield` +* 'std' Dialect + * `std.br` + * `std.call` + * `std.cond_br` + * `std.constant` + * `std.return` +* 'scf' Dialect + * `scf.for` + * `scf.if` + * `scf.yield` +* 'arith' Dialect + * 'arith.const' diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -120,9 +120,9 @@ affine.for %k = 0 to %n { %a_v = load %A[%i, %k] : memref<100x?xf32> %b_v = load %B[%k, %j] : memref - %prod = mulf %a_v, %b_v : f32 + %prod = arith.mulf %a_v, %b_v : f32 %c_v = load %C[%i, %j] : memref<100x50xf32> - %sum = addf %c_v, %prod : f32 + %sum = arith.addf %c_v, %prod : f32 store %sum, %C[%i, %j] : memref<100x50xf32> } } @@ -389,7 +389,7 @@ br ^bb3(%a: i64) // Branch passes %a as the argument ^bb2: - %b = addi %a, %a : i64 + %b = arith.addi %a, %a : i64 br ^bb3(%b: i64) // Branch passes %b as the argument // ^bb3 receives an argument, named %c, from predecessors @@ -400,7 +400,7 @@ br ^bb4(%c, %a : i64, i64) ^bb4(%d : i64, %e : i64): - %0 = addi %d, %e : i64 + %0 = arith.addi %d, %e : i64 return %0 : i64 // Return is also a terminator. } ``` @@ -756,7 +756,7 @@ - *inherent attributes* are inherent to the definition of an operation's semantics. 
The operation itself is expected to verify the consistency of these - attributes. An example is the `predicate` attribute of the `std.cmpi` op. + attributes. An example is the `predicate` attribute of the `arith.cmpi` op. These attributes must have names that do not start with a dialect prefix. - *discardable attributes* have semantics defined externally to the operation diff --git a/mlir/docs/Rationale/MLIRForGraphAlgorithms.md b/mlir/docs/Rationale/MLIRForGraphAlgorithms.md --- a/mlir/docs/Rationale/MLIRForGraphAlgorithms.md +++ b/mlir/docs/Rationale/MLIRForGraphAlgorithms.md @@ -156,7 +156,7 @@ ```mlir // RUN: mlir-opt %s -canonicalize | FileCheck %s func @test_subi_zero_cfg(%arg0: i32) -> i32 { - %y = subi %arg0, %arg0 : i32 + %y = arith.subi %arg0, %arg0 : i32 return %y: i32 } // CHECK-LABEL: func @test_subi_zero_cfg(%arg0: i32) diff --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md --- a/mlir/docs/Rationale/Rationale.md +++ b/mlir/docs/Rationale/Rationale.md @@ -560,12 +560,12 @@ br ^bb1(0) ^bb1(%j: i32) - %p1 = cmpi "lt", %j, %nj : i32 + %p1 = arith.cmpi "lt", %j, %nj : i32 cond_br %p1, ^bb2, ^bb5 ^bb2: %v = affine.load %A[%i, %j] : memref - %p2 = cmpi "eq", %v, %key : i32 + %p2 = arith.cmpi "eq", %v, %key : i32 cond_br %p2, ^bb3(%j), ^bb4 ^bb3(%j: i32) @@ -573,7 +573,7 @@ br ^bb5 ^bb4: - %jinc = addi %j, 1 : i32 + %jinc = arith.addi %j, 1 : i32 br ^bb1(%jinc) ^bb5: @@ -844,7 +844,7 @@ bb0 (%0, %1: memref<128xf32>, i64): %val = affine.load %A [%pos] %val = affine.load %A [%pos + 1] - %p = mulf %val, %val : f32 + %p = arith.mulf %val, %val : f32 return %p : f32 } ``` diff --git a/mlir/docs/SymbolsAndSymbolTables.md b/mlir/docs/SymbolsAndSymbolTables.md --- a/mlir/docs/SymbolsAndSymbolTables.md +++ b/mlir/docs/SymbolsAndSymbolTables.md @@ -137,9 +137,9 @@ different trade offs depending on the situation. 
A function call may directly use a `SymbolRef` as the callee, whereas a reference to a global variable might use a materialization operation so that the variable can be used in other -operations like `std.addi`. -[`llvm.mlir.addressof`](Dialects/LLVM.md/#llvmmliraddressof-mlirllvmaddressofop) is one example of -such an operation. +operations like `arith.addi`. +[`llvm.mlir.addressof`](Dialects/LLVM.md/#llvmmliraddressof-mlirllvmaddressofop) +is one example of such an operation. See the `LangRef` definition of the [`SymbolRefAttr`](Dialects/Builtin.md/#symbolrefattr) for more information diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md --- a/mlir/docs/TargetLLVMIR.md +++ b/mlir/docs/TargetLLVMIR.md @@ -779,27 +779,27 @@ // dynamic, extract the stride value from the descriptor. %stride1 = llvm.extractvalue[4, 0] : !llvm.struct<(ptr, ptr, i64, array<4xi64>, array<4xi64>)> -%addr1 = muli %stride1, %1 : i64 +%addr1 = arith.muli %stride1, %1 : i64 // When the stride or, in absence of explicit strides, the trailing sizes are // known statically, this value is used as a constant. The natural value of // strides is the product of all sizes following the current dimension. %stride2 = llvm.mlir.constant(32 : index) : i64 -%addr2 = muli %stride2, %2 : i64 -%addr3 = addi %addr1, %addr2 : i64 +%addr2 = arith.muli %stride2, %2 : i64 +%addr3 = arith.addi %addr1, %addr2 : i64 %stride3 = llvm.mlir.constant(8 : index) : i64 -%addr4 = muli %stride3, %3 : i64 -%addr5 = addi %addr3, %addr4 : i64 +%addr4 = arith.muli %stride3, %3 : i64 +%addr5 = arith.addi %addr3, %addr4 : i64 // Multiplication with the known unit stride can be omitted. -%addr6 = addi %addr5, %4 : i64 +%addr6 = arith.addi %addr5, %4 : i64 // If the linear offset is known to be zero, it can also be omitted. If it is // dynamic, it is extracted from the descriptor. 
%offset = llvm.extractvalue[2] : !llvm.struct<(ptr, ptr, i64, array<4xi64>, array<4xi64>)> -%addr7 = addi %addr6, %offset : i64 +%addr7 = arith.addi %addr6, %offset : i64 // All accesses are based on the aligned pointer. %aligned = llvm.extractvalue[1] : !llvm.struct<(ptr, ptr, i64, diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md --- a/mlir/docs/Tutorials/Toy/Ch-5.md +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -275,7 +275,7 @@ affine.for %arg1 = 0 to 2 { %3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64> %4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64> - %5 = mulf %3, %4 : f64 + %5 = arith.mulf %3, %4 : f64 affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64> } } @@ -324,7 +324,7 @@ %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64> // Multiply and store into the output buffer. - %3 = mulf %2, %2 : f64 + %3 = arith.mulf %2, %2 : f64 affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64> } } diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -209,13 +209,13 @@ 115: %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0 - %117 = bitcast double* %116 to i8* + %117 = arith.bitcast double* %116 to i8* call void @free(i8* %117) %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0 - %119 = bitcast double* %118 to i8* + %119 = arith.bitcast double* %118 to i8* call void @free(i8* %119) %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0 - %121 = bitcast double* %120 to i8* + %121 = arith.bitcast double* %120 to i8* call void @free(i8* %121) ret void } diff --git a/mlir/examples/standalone/standalone-opt/CMakeLists.txt b/mlir/examples/standalone/standalone-opt/CMakeLists.txt --- a/mlir/examples/standalone/standalone-opt/CMakeLists.txt +++ b/mlir/examples/standalone/standalone-opt/CMakeLists.txt @@ -3,6 +3,7 @@ set(LIBS ${dialect_libs} ${conversion_libs} + MLIRArithmetic MLIROptLib MLIRStandalone ) diff 
--git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp --- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp +++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" @@ -27,6 +28,7 @@ mlir::DialectRegistry registry; registry.insert(); + registry.insert(); registry.insert(); // Add the following to include *all* MLIR Core dialects, or selectively // include what you need like above. You only need to register dialects that diff --git a/mlir/examples/standalone/test/Standalone/dummy.mlir b/mlir/examples/standalone/test/Standalone/dummy.mlir --- a/mlir/examples/standalone/test/Standalone/dummy.mlir +++ b/mlir/examples/standalone/test/Standalone/dummy.mlir @@ -3,7 +3,7 @@ module { // CHECK-LABEL: func @bar() func @bar() { - %0 = constant 1 : i32 + %0 = arith.const 1 : i32 // CHECK: %{{.*}} = standalone.foo %{{.*}} : i32 %res = standalone.foo %0 : i32 return diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine 
RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + constantIndices.push_back( + rewriter.create(loc, i)); } else { // This is the case of a tensor of rank 0. - constantIndices.push_back(rewriter.create(loc, 0)); + constantIndices.push_back( + rewriter.create(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -171,7 +174,7 @@ // we store the element at the given index. if (dimension == valueShape.size()) { rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, + loc, rewriter.create(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -284,9 +287,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. 
Given that we actually want diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + constantIndices.push_back( + rewriter.create(loc, i)); } else { // This is the case of a tensor of rank 0. - constantIndices.push_back(rewriter.create(loc, 0)); + constantIndices.push_back( + rewriter.create(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we // will need to generate a store for each of the elements. The following @@ -170,7 +173,7 @@ // we store the element at the given index. if (dimension == valueShape.size()) { rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, + loc, rewriter.create(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -283,9 +286,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. 
+ target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -25,6 +25,7 @@ #include "toy/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -32,6 +33,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -73,9 +75,10 @@ // Create a loop for each of the dimensions within the shape. 
SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = + rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); auto loop = rewriter.create(loc, lowerBound, upperBound, step); for (Operation &nested : *loop.getBody()) @@ -198,6 +201,7 @@ RewritePatternSet patterns(&getContext()); populateAffineToStdConversionPatterns(patterns); populateLoopToStdConversionPatterns(patterns); + populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -16,6 +16,7 @@ #include "toy/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" @@ -124,8 +125,8 @@ return success(); } }; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Constant operations @@ -154,10 +155,12 @@ if (!valueShape.empty()) { for (auto i : llvm::seq( 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + constantIndices.push_back( + rewriter.create(loc, i)); } else { // This is the case of a tensor of rank 0. 
- constantIndices.push_back(rewriter.create(loc, 0)); + constantIndices.push_back( + rewriter.create(loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -171,7 +174,7 @@ // we store the element at the given index. if (dimension == valueShape.size()) { rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, + loc, rewriter.create(loc, *valueIt++), alloc, llvm::makeArrayRef(indices)); return; } @@ -284,9 +287,9 @@ // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `MemRef` and `Standard` dialects. - target.addLegalDialect(); + // `Affine`, `Arithmetic`, `MemRef`, and `Standard` dialects. + target.addLegalDialect(); // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -25,6 +25,7 @@ #include "toy/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -32,6 +33,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -73,9 +75,10 @@ // Create a loop for each of the dimensions within the shape. 
SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = + rewriter.create(loc, memRefShape[i]); + auto step = rewriter.create(loc, 1); auto loop = rewriter.create(loc, lowerBound, upperBound, step); for (Operation &nested : *loop.getBody()) @@ -198,6 +201,7 @@ RewritePatternSet patterns(&getContext()); populateAffineToStdConversionPatterns(patterns); populateLoopToStdConversionPatterns(patterns); + populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); diff --git a/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h b/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h @@ -0,0 +1,27 @@ +//===- ArithmeticToLLVM.h - Arith to LLVM dialect conversion ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H +#define MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H + +#include + +namespace mlir { + +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; + +void populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +std::unique_ptr createConvertArithmeticToLLVMPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHMETICTOLLVM_ARITHMETICTOLLVM_H diff --git a/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h @@ -0,0 +1,22 @@ +//===- ArithmeticToSPIRV.h - Convert Arith to SPIRV dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H +#define MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H + +namespace mlir { + +class SPIRVTypeConverter; +class RewritePatternSet; + +void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // end namespace mlir + +#endif // MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -10,6 +10,7 @@ #define MLIR_CONVERSION_PASSES_H #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -39,10 +39,10 @@ %d0 = <...> %d1 = <...> %s0 = <...> - %0 = constant 2 : index - %1 = muli %0, %d1 - %2 = addi %d0, %1 - %r = addi %2, %s0 + %0 = arith.const 2 : index + %1 = arith.muli %0, %d1 + %2 = arith.addi %d0, %1 + %r = arith.addi %2, %s0 ``` #### Input invariant @@ -67,6 +67,7 @@ }]; let constructor = "mlir::createLowerAffinePass()"; let dependentDialects = [ + "arith::ArithmeticDialect", "memref::MemRefDialect", "scf::SCFDialect", "StandardOpsDialect", @@ -74,6 +75,19 @@ ]; } +//===----------------------------------------------------------------------===// +// ArithmeticToLLVM +//===----------------------------------------------------------------------===// + +def ConvertArithmeticToLLVM : 
FunctionPass<"convert-arithmetic-to-llvm"> { + let summary = "Convert Arithmetic dialect to LLVM dialect"; + let description = [{ + This pass converts supported Arithmetic ops to LLVM dialect instructions. + }]; + let constructor = "mlir::createConvertArithmeticToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; +} + //===----------------------------------------------------------------------===// // AsyncToLLVM //===----------------------------------------------------------------------===// @@ -86,7 +100,10 @@ API to execute them. }]; let constructor = "mlir::createConvertAsyncToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "LLVM::LLVMDialect", + ]; } //===----------------------------------------------------------------------===// @@ -107,6 +124,7 @@ let summary = "Convert Complex dialect to standard dialect"; let constructor = "mlir::createConvertComplexToStandardPass()"; let dependentDialects = [ + "arith::ArithmeticDialect", "complex::ComplexDialect", "math::MathDialect", "StandardOpsDialect" @@ -136,7 +154,12 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { let summary = "Generate NVVM operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()"; - let dependentDialects = ["NVVM::NVVMDialect", "memref::MemRefDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "memref::MemRefDialect", + "NVVM::NVVMDialect", + "StandardOpsDialect", + ]; let options = [ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", @@ -252,7 +275,11 @@ This pass converts supported Math ops to libm calls. 
}]; let constructor = "mlir::createConvertMathToLibmPass()"; - let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "StandardOpsDialect", + "vector::VectorDialect", + ]; } //===----------------------------------------------------------------------===// @@ -412,7 +439,7 @@ let summary = "Convert SCF dialect to Standard dialect, replacing structured" " control flow with a CFG"; let constructor = "mlir::createLowerToCFGPass()"; - let dependentDialects = ["StandardOpsDialect"]; + let dependentDialects = ["arith::ArithmeticDialect", "StandardOpsDialect"]; } //===----------------------------------------------------------------------===// @@ -422,7 +449,7 @@ def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> { let summary = "Convert top-level AffineFor Ops to GPU kernels"; let constructor = "mlir::createAffineForToGPUPass()"; - let dependentDialects = ["gpu::GPUDialect"]; + let dependentDialects = ["arith::ArithmeticDialect", "gpu::GPUDialect"]; let options = [ Option<"numBlockDims", "gpu-block-dims", "unsigned", /*default=*/"1u", "Number of GPU block dimensions for mapping">, @@ -446,6 +473,7 @@ "dialect"; let constructor = "mlir::createConvertShapeToStandardPass()"; let dependentDialects = [ + "arith::ArithmeticDialect", "StandardOpsDialect", "scf::SCFDialect", "tensor::TensorDialect" @@ -583,7 +611,11 @@ def TosaToStandard : Pass<"tosa-to-standard"> { let summary = "Lower TOSA to the Standard dialect"; - let dependentDialects = ["StandardOpsDialect", "tensor::TensorDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "StandardOpsDialect", + "tensor::TensorDialect", + ]; let description = [{ Pass that converts TOSA operations to the equivalent operations using the operations in the Standard dialect. 
diff --git a/mlir/include/mlir/Conversion/SPIRVCommon/Pattern.h b/mlir/include/mlir/Conversion/SPIRVCommon/Pattern.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/SPIRVCommon/Pattern.h @@ -0,0 +1,44 @@ +//===- Pattern.h - SPIRV Common Conversion Patterns -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H +#define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H + +#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace spirv { + +/// Converts unary and binary standard operations to SPIR-V operations. +template +class UnaryAndBinaryOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + assert(operands.size() <= 2); + auto dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return failure(); + if (SPIRVOp::template hasTrait() && + dstType != op.getType()) { + return op.emitError( + "bitwidth emulation is not implemented yet on unsigned op"); + } + rewriter.template replaceOpWithNewOp(op, dstType, operands); + return success(); + } +}; + +} // end namespace spirv +} // end namespace mlir + +#endif // MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h --- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h +++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h @@ -37,7 +37,7 @@ /// affine.for %I = 0 to 9 { /// %dim = dim %A, 0 : memref /// %add = 
affine.apply %I + %a -/// %cmp = cmpi "slt", %add, %dim : index +/// %cmp = arith.cmpi "slt", %add, %dim : index /// scf.if %cmp { /// %vec_2d = load %1[%I] : memref<9xvector<17x15xf32>> /// vector.transfer_write %vec_2d, %A[%add, %b, %c] : diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -201,7 +201,7 @@ %sum = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32) { %t = affine.load %buffer[%i] : memref<1024xf32> - %sum_next = addf %sum_iter, %t : f32 + %sum_next = arith.addf %sum_iter, %t : f32 // Yield current iteration sum to next iteration %sum_iter or to %sum // if final iteration. affine.yield %sum_next : f32 @@ -213,8 +213,8 @@ ```mlir %res:2 = affine.for %i = 0 to 128 iter_args(%arg0 = %init0, %arg1 = %init1) -> (index, index) { - %y0 = addi %arg0, %c1 : index - %y1 = addi %arg1, %c2 : index + %y0 = arith.addi %arg0, %c1 : index + %y1 = arith.addi %arg1, %c2 : index affine.yield %y0, %y1 : index, index } ``` @@ -652,7 +652,7 @@ %0 = affine.parallel (%kx, %ky) = (0, 0) to (2, 2) reduce ("addf") { %1 = affine.load %D[%x + %kx, %y + %ky] : memref<100x100xf32> %2 = affine.load %K[%kx, %ky] : memref<3x3xf32> - %3 = mulf %1, %2 : f32 + %3 = arith.mulf %1, %2 : f32 affine.yield %3 : f32 } affine.store %0, O[%x, %y] : memref<98x98xf32> diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -112,7 +112,7 @@ affine.for %i1 = 0 to 10 { affine.store %cf7, %m[%i0, %i1] : memref<10x10xf32> %v0 = affine.load %m[%i0, %i1] : memref<10x10xf32> - %v1 = addf %v0, %v0 : f32 + %v1 = arith.addf %v0, %v0 : f32 } } return %m : memref<10x10xf32> @@ -129,7 +129,7 @@ affine.for %arg0 = 0 to 10 { affine.for %arg1 = 0 to 10 { affine.store %cst, 
%0[%arg0, %arg1] : memref<10x10xf32> - %1 = addf %cst, %cst : f32 + %1 = arith.addf %cst, %cst : f32 } } return %0 : memref<10x10xf32> diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -10,6 +10,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" @@ -33,6 +34,64 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.h.inc" +namespace mlir { +namespace arith { + +/// Specialization of `arith.const` op that returns an integer value. +class ConstantIntOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant int op that produces an integer of the specified width. + static void build(OpBuilder &builder, OperationState &result, int64_t value, + unsigned width); + + /// Build a constant int op that produces an integer of the specified type, + /// which must be an integer type. + static void build(OpBuilder &builder, OperationState &result, int64_t value, + Type type); + + inline int64_t value() { + return arith::ConstantOp::value().cast().getInt(); + } + + static bool classof(Operation *op); +}; + +/// Specialization of `arith.const` op that returns a floating point value. +class ConstantFloatOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant float op that produces a float of the specified type. 
+ static void build(OpBuilder &builder, OperationState &result, + const APFloat &value, FloatType type); + + inline APFloat value() { + return arith::ConstantOp::value().cast().getValue(); + } + + static bool classof(Operation *op); +}; + +/// Specialization of `arith.const` op that returns an integer of index type. +class ConstantIndexOp : public arith::ConstantOp { +public: + using arith::ConstantOp::ConstantOp; + + /// Build a constant int op that produces an index. + static void build(OpBuilder &builder, OperationState &result, int64_t value); + + inline int64_t value() { + return arith::ConstantOp::value().cast().getInt(); + } + + static bool classof(Operation *op); +}; + +} // end namespace arith +} // end namespace mlir + //===----------------------------------------------------------------------===// // Utility Functions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td @@ -20,6 +20,8 @@ ops, bitwise and shift ops, cast ops, and compare ops. Operations in this dialect also accept vectors and tensors of integers or floats. }]; + + let hasConstantMaterializer = 1; } // The predicate indicates the type of the comparison to perform: diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -13,6 +13,7 @@ include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" +include "mlir/IR/OpAsmInterface.td" // Base class for Arithmetic dialect ops. 
Ops in this dialect have no side // effects and can be applied element-wise to vectors and tensors. @@ -119,8 +120,10 @@ //===----------------------------------------------------------------------===// def Arith_ConstantOp : Op, + TypesMatchWith< + "result and attribute have the same type", "value", "result", "$_self">]> { let summary = "integer or floating point constant"; let description = [{ @@ -149,6 +152,12 @@ [{ build($_builder, $_state, type, value); }]>, ]; + let extraClassDeclaration = [{ + /// Whether the constant op can be constructed with a particular value and + /// type. + static bool isBuildableWith(Attribute value, Type type); + }]; + let hasFolder = 1; let assemblyFormat = "attr-dict $value"; } @@ -717,10 +726,10 @@ ```mlir %1 = arith.const 21 : i5 // %1 is 0b10101 - %2 = trunci %1 : i5 to i4 // %2 is 0b0101 - %3 = trunci %1 : i5 to i3 // %3 is 0b101 + %2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101 + %3 = arith.trunci %1 : i5 to i3 // %3 is 0b101 - %5 = trunci %0 : vector<2 x i32> to vector<2 x i16> + %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16> ``` }]; @@ -803,7 +812,14 @@ // IndexCastOp //===----------------------------------------------------------------------===// -def Arith_IndexCastOp : Arith_IToICastOp<"index_cast"> { +// Index cast can convert between memrefs of signless integers and indices too. 
+def IndexCastTypeConstraint : TypeConstraint.predicate]>, + "signless-integer-like or memref of signless-integer">; + +def Arith_IndexCastOp : Arith_CastOp<"index_cast", IndexCastTypeConstraint, + IndexCastTypeConstraint> { let summary = "cast between index and integer types"; let description = [{ Casts between scalar or vector integers and corresponding 'index' scalar or @@ -820,8 +836,15 @@ // BitcastOp //===----------------------------------------------------------------------===// -def Arith_BitcastOp : Arith_CastOp<"bitcast", SignlessIntegerOrFloatLike, - SignlessIntegerOrFloatLike> { +// Bitcast can convert between memrefs of signless integers, indices, and +// floats too. +def BitcastTypeConstraint : TypeConstraint.predicate]>, + "signless-integer-or-float-like or memref of signless-integer or float">; + +def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint, + BitcastTypeConstraint> { let summary = "bitcast between values of equal bit width"; let description = [{ Bitcast an integer or floating point value to an integer or floating point @@ -927,10 +950,10 @@ let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); + static arith::CmpIPredicate getPredicateByName(StringRef name); - CmpIPredicate getPredicate() { - return (CmpIPredicate) (*this)->getAttrOfType( + arith::CmpIPredicate getPredicate() { + return (arith::CmpIPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -983,10 +1006,10 @@ let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); + static arith::CmpFPredicate getPredicateByName(StringRef name); - CmpFPredicate getPredicate() { - return (CmpFPredicate) (*this)->getAttrOfType( + arith::CmpFPredicate getPredicate() { + return (arith::CmpFPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; 
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -15,7 +15,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" -include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td" //===----------------------------------------------------------------------===// @@ -460,24 +460,24 @@ ``` }]; let arguments = (ins - CmpFPredicateAttr:$predicate, + Arith_CmpFPredicateAttr:$predicate, ScalableVectorOf<[AnyFloat]>:$lhs, ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar ); let results = (outs ScalableVectorOf<[I1]>:$result); let builders = [ - OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, + OpBuilder<(ins "arith::CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs); }]>]; let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); + static arith::CmpFPredicate getPredicateByName(StringRef name); - CmpFPredicate getPredicate() { - return (CmpFPredicate)(*this)->getAttrOfType( + arith::CmpFPredicate getPredicate() { + return (arith::CmpFPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -520,24 +520,24 @@ }]; let arguments = (ins - CmpIPredicateAttr:$predicate, + Arith_CmpIPredicateAttr:$predicate, ScalableVectorOf<[I8, I16, I32, I64]>:$lhs, ScalableVectorOf<[I8, I16, I32, I64]>:$rhs ); let results = (outs ScalableVectorOf<[I1]>:$result); let builders = [ - OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, + OpBuilder<(ins "arith::CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs); }]>]; let 
extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); + static arith::CmpIPredicate getPredicateByName(StringRef name); - CmpIPredicate getPredicate() { - return (CmpIPredicate)(*this)->getAttrOfType( + arith::CmpIPredicate getPredicate() { + return (arith::CmpIPredicate) (*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -32,7 +32,11 @@ "The target block size for sharding parallel operation."> ]; - let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; + let dependentDialects = [ + "arith::ArithmeticDialect", + "async::AsyncDialect", + "scf::SCFDialect" + ]; } def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> { diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -607,7 +607,7 @@ %1 = "gpu.all_reduce"(%0) ({}) { op = "add" } : (f32) -> (f32) %2 = "gpu.all_reduce"(%0) ({ ^bb(%lhs : f32, %rhs : f32): - %sum = addf %lhs, %rhs : f32 + %sum = arith.addf %lhs, %rhs : f32 "gpu.yield"(%sum) : (f32) -> () }) : (f32) -> (f32) ``` diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -33,8 +33,12 @@ }]; let cppNamespace = "::mlir::linalg"; let dependentDialects = [ - "AffineDialect", "math::MathDialect", "memref::MemRefDialect", - "StandardOpsDialect", "tensor::TensorDialect" + "arith::ArithmeticDialect", + "AffineDialect", + "math::MathDialect", + "memref::MemRefDialect", + "StandardOpsDialect", + "tensor::TensorDialect", ]; let 
hasCanonicalizer = 1; let hasOperationAttrVerify = 1; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -561,8 +561,8 @@ outs(%C : memref) {other-optional-attributes} { ^bb0(%a: f32, %b: f32, %c: f32) : - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 linalg.yield %e : f32 } ``` @@ -584,8 +584,8 @@ %a = load %A[%m, %k] : memref %b = load %B[%k, %n] : memref %c = load %C[%m, %n] : memref - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 store %e, %C[%m, %n] : memref } } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_LINALG_LINALGTYPES_H_ #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -141,6 +141,7 @@ let summary = "Bufferize the linalg dialect"; let constructor = "mlir::createLinalgBufferizePass()"; let dependentDialects = [ + "arith::ArithmeticDialect", "linalg::LinalgDialect", "AffineDialect", "memref::MemRefDialect" diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -271,7 +271,7 @@ 
/// to /// /// %iv = %lb + %procId * %step - /// %cond = cmpi "slt", %iv, %ub + /// %cond = arith.cmpi "slt", %iv, %ub /// scf.if %cond { /// ... /// } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -156,7 +156,7 @@ omp.wsloop (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) { %a = load %arrA[%i1, %i2] : memref %b = load %arrB[%i1, %i2] : memref - %sum = addf %a, %b : f32 + %sum = arith.addf %a, %b : f32 store %sum, %arrC[%i1, %i2] : memref omp.yield } diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -69,7 +69,7 @@ "Perform tiling with fixed upper bound with inbound check " "inside the internal loops"> ]; - let dependentDialects = ["AffineDialect"]; + let dependentDialects = ["arith::ArithmeticDialect", "AffineDialect"]; } def SCFForLoopRangeFolding @@ -94,18 +94,18 @@ ```mlir # Before: scf.for %i = %c0 to %arg1 step %c1 { - %0 = addi %arg2, %arg2 : i32 + %0 = arith.addi %arg2, %arg2 : i32 memref.store %0, %arg0[%i] : memref } # After: %0 = scf.while (%i = %c0) : (index) -> index { - %1 = cmpi slt, %i, %arg1 : index + %1 = arith.cmpi slt, %i, %arg1 : index scf.condition(%1) %i : index } do { ^bb0(%i: index): // no predecessors - %1 = addi %i, %c1 : index - %2 = addi %arg2, %arg2 : i32 + %1 = arith.addi %i, %c1 : index + %2 = arith.addi %arg2, %arg2 : i32 memref.store %2, %arg0[%i] : memref scf.yield %1 : index } diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -170,7 +170,7 @@ %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) { %t = load %buffer[%iv] : memref<1024xf32> - 
%sum_next = addf %sum_iter, %t : f32 + %sum_next = arith.addf %sum_iter, %t : f32 // Yield current iteration sum to next iteration %sum_iter or to %sum // if final iteration. scf.yield %sum_next : f32 @@ -194,9 +194,9 @@ %sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) { %t = load %buffer[%iv] : memref<1024xf32> - %cond = cmpf "ugt", %t, %c0 : f32 + %cond = arith.cmpf "ugt", %t, %c0 : f32 %sum_next = scf.if %cond -> (f32) { - %new_sum = addf %sum_iter, %t : f32 + %new_sum = arith.addf %sum_iter, %t : f32 scf.yield %new_sum : f32 } else { scf.yield %sum_iter : f32 @@ -451,7 +451,7 @@ %elem_to_reduce = load %buffer[%iv] : memref<100xf32> scf.reduce(%elem_to_reduce) : f32 { ^bb0(%lhs : f32, %rhs: f32): - %res = addf %lhs, %rhs : f32 + %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 } } @@ -519,7 +519,7 @@ %operand = constant 1.0 : f32 scf.reduce(%operand) : f32 { ^bb0(%lhs : f32, %rhs: f32): - %res = addf %lhs, %rhs : f32 + %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 } ``` diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -43,8 +43,8 @@ ins(%arga, %argb: tensor, tensor) outs(%argx: tensor) { ^bb(%a: f64, %b: f64, %x: f64): - %0 = mulf %a, %b : f64 - %1 = addf %x, %0 : f64 + %0 = arith.mulf %a, %b : f64 + %1 = arith.addf %x, %0 : f64 linalg.yield %1 : f64 } -> tensor return %0 : tensor diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/Builders.h" #include 
"mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -41,77 +42,15 @@ #include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc" namespace mlir { -/// This is a refinement of the "constant" op for the case where it is -/// returning a float value of FloatType. -/// -/// %1 = "std.constant"(){value: 42.0} : bf16 -/// -class ConstantFloatOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - - /// Builds a constant float op producing a float of the specified type. - static void build(OpBuilder &builder, OperationState &result, - const APFloat &value, FloatType type); - - APFloat getValue() { - return (*this)->getAttrOfType("value").getValue(); - } - - static bool classof(Operation *op); -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of IntegerType. -/// -/// %1 = "std.constant"(){value: 42} : i32 -/// -class ConstantIntOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - /// Build a constant int op producing an integer of the specified width. - static void build(OpBuilder &builder, OperationState &result, int64_t value, - unsigned width); - - /// Build a constant int op producing an integer with the specified type, - /// which must be an integer type. - static void build(OpBuilder &builder, OperationState &result, int64_t value, - Type type); - - int64_t getValue() { - return (*this)->getAttrOfType("value").getInt(); - } - - static bool classof(Operation *op); -}; - -/// This is a refinement of the "constant" op for the case where it is -/// returning an integer value of Index type. -/// -/// %1 = "std.constant"(){value: 99} : () -> index -/// -class ConstantIndexOp : public ConstantOp { -public: - using ConstantOp::ConstantOp; - - /// Build a constant int op producing an index. 
- static void build(OpBuilder &builder, OperationState &result, int64_t value); - - int64_t getValue() { - return (*this)->getAttrOfType("value").getInt(); - } - - static bool classof(Operation *op); -}; /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer /// comparison predicates. -bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, +bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs); /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point /// comparison predicates. -bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, +bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs); /// Returns the identity value attribute associated with an AtomicRMWKind op. diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -25,6 +25,7 @@ def StandardOps_Dialect : Dialect { let name = "std"; let cppNamespace = "::mlir"; + let dependentDialects = ["arith::ArithmeticDialect"]; let hasConstantMaterializer = 1; } @@ -182,138 +183,6 @@ [DeclareOpInterfaceMethods])>, Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>; -//===----------------------------------------------------------------------===// -// AbsFOp -//===----------------------------------------------------------------------===// - -def AbsFOp : FloatUnaryOp<"absf"> { - let summary = "floating point absolute-value operation"; - let description = [{ - The `absf` operation computes the absolute value. It takes one operand and - returns one result of the same type. This type may be a float scalar type, - a vector whose element type is float, or a tensor of floats. - - Example: - - ```mlir - // Scalar absolute value. - %a = absf %b : f64 - - // SIMD vector element-wise absolute value. 
- %f = absf %g : vector<4xf32> - - // Tensor element-wise absolute value. - %x = absf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// AddFOp -//===----------------------------------------------------------------------===// - -def AddFOp : FloatBinaryOp<"addf"> { - let summary = "floating point addition operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.addf` ssa-use `,` ssa-use `:` type - ``` - - The `addf` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. - - Example: - - ```mlir - // Scalar addition. - %a = addf %b, %c : f64 - - // SIMD vector addition, e.g. for Intel SSE. - %f = addf %g, %h : vector<4xf32> - - // Tensor addition. - %x = addf %y, %z : tensor<4x?xbf16> - ``` - - TODO: In the distant future, this will accept optional attributes for fast - math, contraction, rounding mode, and other controls. - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// AddIOp -//===----------------------------------------------------------------------===// - -def AddIOp : IntBinaryOp<"addi", [Commutative]> { - let summary = "integer addition operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.addi` ssa-use `,` ssa-use `:` type - ``` - - The `addi` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be an integer scalar - type, a vector whose element type is integer, or a tensor of integers. It - has no standard attributes. - - Example: - - ```mlir - // Scalar addition. - %a = addi %b, %c : i64 - - // SIMD vector element-wise addition, e.g. for Intel SSE. - %f = addi %g, %h : vector<4xi32> - - // Tensor element-wise addition. 
- %x = addi %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// AndOp -//===----------------------------------------------------------------------===// - -def AndOp : IntBinaryOp<"and", [Commutative]> { - let summary = "integer binary and"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.and` ssa-use `,` ssa-use `:` type - ``` - - The `and` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise and. - %a = and %b, %c : i64 - - // SIMD vector element-wise bitwise integer and. - %f = and %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer and. - %x = and %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // AssertOp //===----------------------------------------------------------------------===// @@ -413,7 +282,7 @@ %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%current_value : f32): %c1 = constant 1.0 : f32 - %inc = addf %c1, %current_value : f32 + %inc = arith.addf %c1, %current_value : f32 atomic_yield %inc : f32 } ``` @@ -456,32 +325,6 @@ let assemblyFormat = "$result attr-dict `:` type($result)"; } -//===----------------------------------------------------------------------===// -// BitcastOp -//===----------------------------------------------------------------------===// - -def BitcastOp : ArithmeticCastOp<"bitcast"> { - let summary = "bitcast between values of equal bit width"; - let description = [{ - Bitcast an integer or floating point value to an integer or floating point - value of equal bit width. When operating on vectors, casts elementwise. 
- - Note that this implements a logical bitcast independent of target - endianness. This allows constant folding without target information and is - consitent with the bitcast constant folders in LLVM (see - https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168) - For targets where the source and target type have the same endianness (which - is the standard), this cast will also change no bits at runtime, but it may - still require an operation, for example if the machine has different - floating point and integer register files. For targets that have a different - endianness for the source and target types (e.g. float is big-endian and - integer is little-endian) a proper lowering would add operations to swap the - order of words in addition to the bitcast. - }]; - let hasFolder = 1; -} - - //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// @@ -666,240 +509,6 @@ let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)"; } -//===----------------------------------------------------------------------===// -// CeilFOp -//===----------------------------------------------------------------------===// - -def CeilFOp : FloatUnaryOp<"ceilf"> { - let summary = "ceiling of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.ceilf` ssa-use `:` type - ``` - - The `ceilf` operation computes the ceiling of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar ceiling value. - %a = ceilf %b : f64 - - // SIMD vector element-wise ceiling value. - %f = ceilf %g : vector<4xf32> - - // Tensor element-wise ceiling value. 
- %x = ceilf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// FloorFOp -//===----------------------------------------------------------------------===// - -def FloorFOp : FloatUnaryOp<"floorf"> { - let summary = "floor of the specified value"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.floorf` ssa-use `:` type - ``` - - The `floorf` operation computes the floor of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar floor value. - %a = floorf %b : f64 - - // SIMD vector element-wise floor value. - %f = floorf %g : vector<4xf32> - - // Tensor element-wise floor value. - %x = floorf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, - DeclareOpInterfaceMethods, TypesMatchWith< - "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> { - let summary = "floating-point comparison operation"; - let description = [{ - The `cmpf` operation compares its two operands according to the float - comparison rules and the predicate specified by the respective attribute. - The predicate defines the type of comparison: (un)orderedness, (in)equality - and signed less/greater than (or equal to) as well as predicates that are - always true or false. The operands must have the same type, and this type - must be a float type, or a vector or tensor thereof. The result is an i1, - or a vector/tensor thereof having the same shape as the inputs. 
Unlike cmpi, - the operands are always treated as signed. The u prefix indicates - *unordered* comparison, not unsigned comparison, so "une" means unordered or - not equal. For the sake of readability by humans, custom assembly form for - the operation uses a string-typed attribute for the predicate. The value of - this attribute corresponds to lower-cased name of the predicate constant, - e.g., "one" means "ordered not equal". The string representation of the - attribute is merely a syntactic sugar and is converted to an integer - attribute by the parser. - - Example: - - ```mlir - %r1 = cmpf "oeq" %0, %1 : f32 - %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> - %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 - ``` - }]; - - let arguments = (ins - CmpFPredicateAttr:$predicate, - FloatLike:$lhs, - FloatLike:$rhs - ); - let results = (outs BoolLike:$result); - - let builders = [ - OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - ::buildCmpFOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); - - CmpFPredicate getPredicate() { - return (CmpFPredicate)(*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let hasFolder = 1; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - -//===----------------------------------------------------------------------===// -// CmpIOp -//===----------------------------------------------------------------------===// - -def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, - DeclareOpInterfaceMethods, TypesMatchWith< - "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> { - let summary = "integer comparison operation"; - let description = [{ - The `cmpi` 
operation is a generic comparison for integer-like types. Its two - arguments can be integers, vectors or tensors thereof as long as their types - match. The operation produces an i1 for the former case, a vector or a - tensor of i1 with the same shape as inputs in the other cases. - - Its first argument is an attribute that defines which type of comparison is - performed. The following comparisons are supported: - - - equal (mnemonic: `"eq"`; integer value: `0`) - - not equal (mnemonic: `"ne"`; integer value: `1`) - - signed less than (mnemonic: `"slt"`; integer value: `2`) - - signed less than or equal (mnemonic: `"sle"`; integer value: `3`) - - signed greater than (mnemonic: `"sgt"`; integer value: `4`) - - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`) - - unsigned less than (mnemonic: `"ult"`; integer value: `6`) - - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`) - - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`) - - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`) - - The result is `1` if the comparison is true and `0` otherwise. For vector or - tensor operands, the comparison is performed elementwise and the element of - the result indicates whether the comparison is true for the operand elements - with the same indices as those of the result. - - Note: while the custom assembly form uses strings, the actual underlying - attribute has integer type (or rather enum class in C++ code) as seen from - the generic assembly form. String literals are used to improve readability - of the IR by humans. - - This operation only applies to integer-like operands, but not floats. 
The - main reason being that comparison operations have diverging sets of - attributes: integers require sign specification while floats require various - floating point-related particularities, e.g., `-ffast-math` behavior, - IEEE754 compliance, etc - ([rationale](../Rationale/Rationale.md#splitting-floating-point-vs-integer-operations)). - The type of comparison is specified as attribute to avoid introducing ten - similar operations, taking into account that they are often implemented - using the same operation downstream - ([rationale](../Rationale/Rationale.md#specifying-comparison-kind-as-attribute)). The - separation between signed and unsigned order comparisons is necessary - because of integers being signless. The comparison operation must know how - to interpret values with the foremost bit being set: negatives in two's - complement or large positives - ([rationale](../Rationale/Rationale.md#specifying-sign-in-integer-comparison-operations)). - - Example: - - ```mlir - // Custom form of scalar "signed less than" comparison. - %x = cmpi "slt", %lhs, %rhs : i32 - - // Generic form of the same operation. - %x = "std.cmpi"(%lhs, %rhs) {predicate = 2 : i64} : (i32, i32) -> i1 - - // Custom form of vector equality comparison. - %x = cmpi "eq", %lhs, %rhs : vector<4xi64> - - // Generic form of the same operation. 
- %x = "std.cmpi"(%lhs, %rhs) {predicate = 0 : i64} - : (vector<4xi64>, vector<4xi64>) -> vector<4xi1> - ``` - }]; - - let arguments = (ins - CmpIPredicateAttr:$predicate, - SignlessIntegerLike:$lhs, - SignlessIntegerLike:$rhs - ); - let results = (outs BoolLike:$result); - - let builders = [ - OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - ::buildCmpIOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); - - CmpIPredicate getPredicate() { - return (CmpIPredicate)(*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let hasFolder = 1; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// @@ -1095,264 +704,111 @@ } //===----------------------------------------------------------------------===// -// CopySignOp +// MaxFOp //===----------------------------------------------------------------------===// -def CopySignOp : FloatBinaryOp<"copysign"> { - let summary = "A copysign operation"; +def MaxFOp : FloatBinaryOp<"maxf"> { + let summary = "floating-point maximum operation"; let description = [{ Syntax: ``` - operation ::= ssa-id `=` `std.copysign` ssa-use `,` ssa-use `:` type + operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type ``` - The `copysign` returns a value with the magnitude of the first operand and - the sign of the second operand. It takes two operands and returns one - result of the same type. This type may be a float scalar type, a vector - whose element type is float, or a tensor of floats. It has no standard - attributes. 
+ Returns the maximum of the two arguments, treating -0.0 as less than +0.0. + If one of the arguments is NaN, then the result is also NaN. Example: ```mlir - // Scalar copysign value. - %a = copysign %b, %c : f64 - - // SIMD vector element-wise copysign value. - %f = copysign %g, %h : vector<4xf32> - - // Tensor element-wise copysign value. - %x = copysign %y, %z : tensor<4x?xf8> + // Scalar floating-point maximum. + %a = maxf %b, %c : f64 ``` }]; } //===----------------------------------------------------------------------===// -// DivFOp -//===----------------------------------------------------------------------===// - -def DivFOp : FloatBinaryOp<"divf"> { - let summary = "floating point division operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// FmaFOp +// MaxSIOp //===----------------------------------------------------------------------===// -def FmaFOp : FloatTernaryOp<"fmaf"> { - let summary = "floating point fused multipy-add operation"; +def MaxSIOp : IntBinaryOp<"maxsi"> { + let summary = "signed integer maximum operation"; let description = [{ Syntax: ``` - operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type + operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type ``` - The `fmaf` operation takes three operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. + Returns the larger of %a and %b comparing the values as signed integers. Example: ```mlir - // Scalar fused multiply-add: d = a*b + c - %d = fmaf %a, %b, %c : f64 - - // SIMD vector fused multiply-add, e.g. for Intel SSE. - %i = fmaf %f, %g, %h : vector<4xf32> - - // Tensor fused multiply-add. - %w = fmaf %x, %y, %z : tensor<4x?xbf16> + // Scalar signed integer maximum. 
+ %a = maxsi %b, %c : i64 ``` - - The semantics of the operation correspond to those of the `llvm.fma` - [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the - particular case of lowering to LLVM, this is guaranteed to lower - to the `llvm.fma.*` intrinsic. }]; } //===----------------------------------------------------------------------===// -// FPExtOp +// MaxUIOp //===----------------------------------------------------------------------===// -def FPExtOp : ArithmeticCastOp<"fpext"> { - let summary = "cast from floating-point to wider floating-point"; +def MaxUIOp : IntBinaryOp<"maxui"> { + let summary = "unsigned integer maximum operation"; let description = [{ - Cast a floating-point value to a larger floating-point-typed value. - The destination type must to be strictly wider than the source type. - When operating on vectors, casts elementwise. - }]; -} + Syntax: -//===----------------------------------------------------------------------===// -// FPToSIOp -//===----------------------------------------------------------------------===// + ``` + operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type + ``` -def FPToSIOp : ArithmeticCastOp<"fptosi"> { - let summary = "cast from floating-point type to integer type"; - let description = [{ - Cast from a value interpreted as floating-point to the nearest (rounding - towards zero) signed integer value. When operating on vectors, casts - elementwise. - }]; -} + Returns the larger of %a and %b comparing the values as unsigned integers. -//===----------------------------------------------------------------------===// -// FPToUIOp -//===----------------------------------------------------------------------===// + Example: -def FPToUIOp : ArithmeticCastOp<"fptoui"> { - let summary = "cast from floating-point type to integer type"; - let description = [{ - Cast from a value interpreted as floating-point to the nearest (rounding - towards zero) unsigned integer value. 
When operating on vectors, casts - elementwise. + ```mlir + // Scalar unsigned integer maximum. + %a = maxui %b, %c : i64 + ``` }]; } //===----------------------------------------------------------------------===// -// FPTruncOp +// MinFOp //===----------------------------------------------------------------------===// -def FPTruncOp : ArithmeticCastOp<"fptrunc"> { - let summary = "cast from floating-point to narrower floating-point"; +def MinFOp : FloatBinaryOp<"minf"> { + let summary = "floating-point minimum operation"; let description = [{ - Truncate a floating-point value to a smaller floating-point-typed value. - The destination type must be strictly narrower than the source type. - If the value cannot be exactly represented, it is rounded using the default - rounding mode. When operating on vectors, casts elementwise. - }]; + Syntax: - let hasFolder = 1; + ``` + operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type + ``` + + Returns the minimum of the two arguments, treating -0.0 as less than +0.0. + If one of the arguments is NaN, then the result is also NaN. + + Example: + + ```mlir + // Scalar floating-point minimum. + %a = minf %b, %c : f64 + ``` + }]; } //===----------------------------------------------------------------------===// -// IndexCastOp +// MinSIOp //===----------------------------------------------------------------------===// -def IndexCastOp : ArithmeticCastOp<"index_cast"> { - let summary = "cast between index and integer types"; +def MinSIOp : IntBinaryOp<"minsi"> { + let summary = "signed integer minimum operation"; let description = [{ - Casts between scalar or vector integers and corresponding 'index' scalar or - vectors. Index is an integer of platform-specific bit width. If casting to - a wider integer, the value is sign-extended. If casting to a narrower - integer, the value is truncated. 
- }]; - - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// MaxFOp -//===----------------------------------------------------------------------===// - -def MaxFOp : FloatBinaryOp<"maxf"> { - let summary = "floating-point maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type - ``` - - Returns the maximum of the two arguments, treating -0.0 as less than +0.0. - If one of the arguments is NaN, then the result is also NaN. - - Example: - - ```mlir - // Scalar floating-point maximum. - %a = maxf %b, %c : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MaxSIOp -//===----------------------------------------------------------------------===// - -def MaxSIOp : IntBinaryOp<"maxsi"> { - let summary = "signed integer maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type - ``` - - Returns the larger of %a and %b comparing the values as signed integers. - - Example: - - ```mlir - // Scalar signed integer maximum. - %a = maxsi %b, %c : i64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MaxUIOp -//===----------------------------------------------------------------------===// - -def MaxUIOp : IntBinaryOp<"maxui"> { - let summary = "unsigned integer maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type - ``` - - Returns the larger of %a and %b comparing the values as unsigned integers. - - Example: - - ```mlir - // Scalar unsigned integer maximum. 
- %a = maxui %b, %c : i64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MinFOp -//===----------------------------------------------------------------------===// - -def MinFOp : FloatBinaryOp<"minf"> { - let summary = "floating-point minimum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type - ``` - - Returns the minimum of the two arguments, treating -0.0 as less than +0.0. - If one of the arguments is NaN, then the result is also NaN. - - Example: - - ```mlir - // Scalar floating-point minimum. - %a = minf %b, %c : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MinSIOp -//===----------------------------------------------------------------------===// - -def MinSIOp : IntBinaryOp<"minsi"> { - let summary = "signed integer minimum operation"; - let description = [{ - Syntax: + Syntax: ``` operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type @@ -1393,119 +849,6 @@ }]; } -//===----------------------------------------------------------------------===// -// MulFOp -//===----------------------------------------------------------------------===// - -def MulFOp : FloatBinaryOp<"mulf"> { - let summary = "floating point multiplication operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.mulf` ssa-use `,` ssa-use `:` type - ``` - - The `mulf` operation takes two operands and returns one result, each of - these is required to be the same type. This type may be a floating point - scalar type, a vector whose element type is a floating point type, or a - floating point tensor. - - Example: - - ```mlir - // Scalar multiplication. - %a = mulf %b, %c : f64 - - // SIMD pointwise vector multiplication, e.g. for Intel SSE. - %f = mulf %g, %h : vector<4xf32> - - // Tensor pointwise multiplication. 
- %x = mulf %y, %z : tensor<4x?xbf16> - ``` - - TODO: In the distant future, this will accept optional attributes for fast - math, contraction, rounding mode, and other controls. - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// MulIOp -//===----------------------------------------------------------------------===// - -def MulIOp : IntBinaryOp<"muli", [Commutative]> { - let summary = "integer multiplication operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// NegFOp -//===----------------------------------------------------------------------===// - -def NegFOp : FloatUnaryOp<"negf"> { - let summary = "floating point negation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `negf` ssa-use `:` type - ``` - - The `negf` operation computes the negation of a given value. It takes one - operand and returns one result of the same type. This type may be a float - scalar type, a vector whose element type is float, or a tensor of floats. - It has no standard attributes. - - Example: - - ```mlir - // Scalar negation value. - %a = negf %b : f64 - - // SIMD vector element-wise negation value. - %f = negf %g : vector<4xf32> - - // Tensor element-wise negation value. - %x = negf %y : tensor<4x?xf8> - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// OrOp -//===----------------------------------------------------------------------===// - -def OrOp : IntBinaryOp<"or", [Commutative]> { - let summary = "integer binary or"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `or` ssa-use `,` ssa-use `:` type - ``` - - The `or` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. 
It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise or. - %a = or %b, %c : i64 - - // SIMD vector element-wise bitwise integer or. - %f = or %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer or. - %x = or %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// @@ -1538,14 +881,6 @@ let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)"; } -//===----------------------------------------------------------------------===// -// RemFOp -//===----------------------------------------------------------------------===// - -def RemFOp : FloatBinaryOp<"remf"> { - let summary = "floating point division remainder operation"; -} - //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// @@ -1641,236 +976,6 @@ let hasFolder = 1; } -//===----------------------------------------------------------------------===// -// ShiftLeftOp -//===----------------------------------------------------------------------===// - -def ShiftLeftOp : IntBinaryOp<"shift_left"> { - let summary = "integer left-shift"; - let description = [{ - The shift_left operation shifts an integer value to the left by a variable - amount. The low order bits are filled with zeros. 
- - Example: - - ```mlir - %1 = constant 5 : i8 // %1 is 0b00000101 - %2 = constant 3 : i8 - %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// SignedDivIOp -//===----------------------------------------------------------------------===// - -def SignedDivIOp : IntBinaryOp<"divi_signed"> { - let summary = "signed integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `divi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards zero. Treats the leading bit as - sign, i.e. `6 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. - %a = divi_signed %b, %c : i64 - - // SIMD vector element-wise division. - %f = divi_signed %g, %h : vector<4xi32> - - // Tensor element-wise integer division. - %x = divi_signed %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedFloorDivIOp -//===----------------------------------------------------------------------===// - -def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> { - let summary = "signed floor integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `floordivi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. 
- %a = floordivi_signed %b, %c : i64 - - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedCeilDivIOp -//===----------------------------------------------------------------------===// - -def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> { - let summary = "signed ceil integer division operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `ceildivi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`. - - Note: the semantics of division by zero or signed division overflow (minimum - value divided by -1) is TBD; do NOT assume any specific behavior. - - Example: - - ```mlir - // Scalar signed integer division. - %a = ceildivi_signed %b, %c : i64 - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedRemIOp -//===----------------------------------------------------------------------===// - -def SignedRemIOp : IntBinaryOp<"remi_signed"> { - let summary = "signed integer division remainder operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.remi_signed` ssa-use `,` ssa-use `:` type - ``` - - Signed integer division remainder. Treats the leading bit as sign, i.e. `6 % - -2 = 0`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar signed integer division remainder. - %a = remi_signed %b, %c : i64 - - // SIMD vector element-wise division remainder. - %f = remi_signed %g, %h : vector<4xi32> - - // Tensor element-wise integer division remainder. 
- %x = remi_signed %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SignedShiftRightOp -//===----------------------------------------------------------------------===// - -def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> { - let summary = "signed integer right-shift"; - let description = [{ - The shift_right_signed operation shifts an integer value to the right by - a variable amount. The integer is interpreted as signed. The high order - bits in the output are filled with copies of the most-significant bit - of the shifted value (which means that the sign of the value is preserved). - - Example: - - ```mlir - %1 = constant 160 : i8 // %1 is 0b10100000 - %2 = constant 3 : i8 - %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 - %4 = constant 96 : i8 // %4 is 0b01100000 - %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// SignExtendIOp -//===----------------------------------------------------------------------===// - -def SignExtendIOp : Std_Op<"sexti", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer sign extension operation"; - let description = [{ - The integer sign extension operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be larger than the input bit-width (N > M). - The top-most (N - M) bits of the output are filled with copies - of the most-significant bit of the input. 
- - Example: - - ```mlir - %1 = constant 5 : i3 // %1 is 0b101 - %2 = sexti %1 : i3 to i6 // %2 is 0b111101 - %3 = constant 2 : i3 // %3 is 0b010 - %4 = sexti %3 : i3 to i6 // %4 is 0b000010 - - %5 = sexti %0 : vector<2 x i32> to vector<2 x i64> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SIToFPOp -//===----------------------------------------------------------------------===// - -def SIToFPOp : ArithmeticCastOp<"sitofp"> { - let summary = "cast from integer type to floating-point"; - let description = [{ - Cast from a value interpreted as a signed integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. When operating on vectors, casts - elementwise. 
- }]; -} - //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// @@ -1918,25 +1023,6 @@ let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } -//===----------------------------------------------------------------------===// -// SubFOp -//===----------------------------------------------------------------------===// - -def SubFOp : FloatBinaryOp<"subf"> { - let summary = "floating point subtraction operation"; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// SubIOp -//===----------------------------------------------------------------------===// - -def SubIOp : IntBinaryOp<"subi"> { - let summary = "integer subtraction operation"; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// @@ -2025,225 +1111,4 @@ let hasCanonicalizer = 1; } -//===----------------------------------------------------------------------===// -// TruncateIOp -//===----------------------------------------------------------------------===// - -def TruncateIOp : Std_Op<"trunci", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer truncation operation"; - let description = [{ - The integer truncation operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be smaller than the input bit-width (N < M). - The top-most (N - M) bits of the input are discarded. 
- - Example: - - ```mlir - %1 = constant 21 : i5 // %1 is 0b10101 - %2 = trunci %1 : i5 to i4 // %2 is 0b0101 - %3 = trunci %1 : i5 to i3 // %3 is 0b101 - - %5 = trunci %0 : vector<2 x i32> to vector<2 x i16> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; - - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UIToFPOp -//===----------------------------------------------------------------------===// - -def UIToFPOp : ArithmeticCastOp<"uitofp"> { - let summary = "cast from unsigned integer type to floating-point"; - let description = [{ - Cast from a value interpreted as unsigned integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. When operating on vectors, casts - elementwise. - }]; -} - -//===----------------------------------------------------------------------===// -// UnsignedDivIOp -//===----------------------------------------------------------------------===// - -def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> { - let summary = "unsigned integer division operation"; - let description = [{ - Syntax: - ``` - operation ::= ssa-id `=` `std.divi_unsigned` ssa-use `,` ssa-use `:` type - ``` - - Unsigned integer division. Rounds towards zero. Treats the leading bit as - the most significant, i.e. for `i16` given two's complement representation, - `6 / -2 = 6 / (2^16 - 2) = 0`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar unsigned integer division. 
- %a = divi_unsigned %b, %c : i64 - - // SIMD vector element-wise division. - %f = divi_unsigned %g, %h : vector<4xi32> - - // Tensor element-wise integer division. - %x = divi_unsigned %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UnsignedRemIOp -//===----------------------------------------------------------------------===// - -def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> { - let summary = "unsigned integer division remainder operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.remi_unsigned` ssa-use `,` ssa-use `:` type - ``` - - Unsigned integer division remainder. Treats the leading bit as the most - significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`. - - Note: the semantics of division by zero is TBD; do NOT assume any specific - behavior. - - Example: - - ```mlir - // Scalar unsigned integer division remainder. - %a = remi_unsigned %b, %c : i64 - - // SIMD vector element-wise division remainder. - %f = remi_unsigned %g, %h : vector<4xi32> - - // Tensor element-wise integer division remainder. - %x = remi_unsigned %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// UnsignedShiftRightOp -//===----------------------------------------------------------------------===// - -def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> { - let summary = "unsigned integer right-shift"; - let description = [{ - The shift_right_unsigned operation shifts an integer value to the right by - a variable amount. The integer is interpreted as unsigned. The high order - bits are always filled with zeros. 
- - Example: - - ```mlir - %1 = constant 160 : i8 // %1 is 0b10100000 - %2 = constant 3 : i8 - %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// XOrOp -//===----------------------------------------------------------------------===// - -def XOrOp : IntBinaryOp<"xor", [Commutative]> { - let summary = "integer binary xor"; - let description = [{ - The `xor` operation takes two operands and returns one result, each of these - is required to be the same type. This type may be an integer scalar type, a - vector whose element type is integer, or a tensor of integers. It has no - standard attributes. - - Example: - - ```mlir - // Scalar integer bitwise xor. - %a = xor %b, %c : i64 - - // SIMD vector element-wise bitwise integer xor. - %f = xor %g, %h : vector<4xi32> - - // Tensor element-wise bitwise integer xor. - %x = xor %y, %z : tensor<4x?xi8> - ``` - }]; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -//===----------------------------------------------------------------------===// -// ZeroExtendIOp -//===----------------------------------------------------------------------===// - -def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, - DeclareOpInterfaceMethods] # - ElementwiseMappable.traits> { - let summary = "integer zero extension operation"; - let description = [{ - The integer zero extension operation takes an integer input of - width M and an integer destination type of width N. The destination - bit-width must be larger than the input bit-width (N > M). - The top-most (N - M) bits of the output are filled with zeros. 
- - Example: - - ```mlir - %1 = constant 5 : i3 // %1 is 0b101 - %2 = zexti %1 : i3 to i6 // %2 is 0b000101 - %3 = constant 2 : i3 // %3 is 0b010 - %4 = zexti %3 : i3 to i6 // %4 is 0b000010 - - %5 = zexti %0 : vector<2 x i32> to vector<2 x i64> - ``` - }]; - - let arguments = (ins SignlessIntegerLike:$value); - let results = (outs SignlessIntegerLike); - - let builders = [ - OpBuilder<(ins "Value":$value, "Type":$destType), [{ - $_state.addOperands(value); - $_state.addTypes(destType); - }]>]; - - let parser = [{ - return impl::parseCastOp(parser, result); - }]; - let printer = [{ - return printStandardCastOp(this->getOperation(), p); - }]; -} - #endif // STANDARD_OPS diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td @@ -36,50 +36,4 @@ let cppNamespace = "::mlir"; } -// The predicate indicates the type of the comparison to perform: -// (un)orderedness, (in)equality and less/greater than (or equal to) as -// well as predicates that are always true or false. 
-def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">; -def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">; -def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">; -def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">; -def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">; -def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">; -def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">; -def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">; -def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">; -def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">; -def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">; -def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">; -def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">; -def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">; -def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">; -def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">; - -def CmpFPredicateAttr : I64EnumAttr< - "CmpFPredicate", "", - [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE, - CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT, - CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> { - let cppNamespace = "::mlir"; -} - -def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; -def CMPI_P_NE : I64EnumAttrCase<"ne", 1>; -def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>; -def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>; -def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>; -def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>; -def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>; -def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>; -def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>; -def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>; - -def CmpIPredicateAttr : I64EnumAttr< - "CmpIPredicate", "", - [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT, - CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> { - let cppNamespace = "::mlir"; -} - #endif // STANDARD_OPS_BASE diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- 
a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -43,7 +43,7 @@ /// Creates an instance of the StdExpand pass that legalizes Std /// dialect ops to be convertible to LLVM. For example, -/// `std.ceildivi_signed` gets transformed to a number of std operations, +/// `std.arith.ceildivsi` gets transformed to a number of std operations, /// which can be lowered to LLVM; `memref.reshape` gets converted to /// `memref_reinterpret_cast`. std::unique_ptr createStdExpandOpsPass(); diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h --- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h @@ -16,6 +16,7 @@ #ifndef MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H #define MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -24,7 +25,7 @@ namespace mlir { /// Matches a ConstantIndexOp. -detail::op_matcher matchConstantIndex(); +detail::op_matcher matchConstantIndex(); /// Detects the `values` produced by a ConstantIndexOp and places the new /// constant in place of the corresponding sentinel value. diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -555,7 +555,7 @@ %idx0 = ... : index // dynamic computation producing the value 1 of index type %idx1 = ... : index - %0 = constant dense<0, 1, 2, 3>: vector<4xi32> + %0 = arith.const dense<0, 1, 2, 3>: vector<4xi32> // extracts values [0, 1] %1 = vector.extract_map %0[%idx0] : vector<4xi32> to vector<2xi32> // extracts values [1, 2] @@ -743,7 +743,7 @@ %idx0 = ... 
: index // dynamic computation producing the value 1 of index type %idx1 = ... : index / - %0 = constant dense<0, 1, 2, 3>: vector<4xi32> + %0 = arith.const dense<0, 1, 2, 3>: vector<4xi32> // extracts values [0, 1] %1 = vector.extract_map %0[%idx0] : vector<4xi32> to vector<2xi32> // extracts values [1, 2] diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -173,9 +173,9 @@ /// canonicalizations pattern to propagate and fold the vector /// insert_map/extract_map operations. /// Transforms: -// %v = addf %a, %b : vector<32xf32> +// %v = arith.addf %a, %b : vector<32xf32> /// to: -/// %v = addf %a, %b : vector<32xf32> +/// %v = arith.addf %a, %b : vector<32xf32> /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32> /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32> Optional diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -325,7 +325,7 @@ %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32> %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32> - %d = addf %1, %2 : f32 + %d = arith.addf %1, %2 : f32 ``` }]; let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a, diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1383,11 +1383,11 @@ /// /// Examples: /// ``` -/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32 +/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32 /// ``` /// can be tensorized to /// ``` -/// %tensor = "std.addf"(%a, %b) : (tensor, tensor) +/// %tensor = 
"arith.addf"(%a, %b) : (tensor, tensor) /// -> tensor /// ``` /// diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -52,6 +53,7 @@ // clang-format off registry.insert - %3 = addf %2, %2 : f32 + %3 = arith.addf %2, %2 : f32 affine.store %3, %arg0[%arg2] : memref<10xf32> } affine.for %arg2 = 0 to 10 { %2 = affine.load %1[%arg2] : memref<10xf32> - %3 = mulf %2, %2 : f32 + %3 = arith.mulf %2, %2 : f32 affine.store %3, %arg1[%arg2] : memref<10xf32> } return @@ -67,10 +67,10 @@ affine.store %cst, %0[0] : memref<1xf32> affine.store %cst, %1[0] : memref<1xf32> %2 = affine.load %1[0] : memref<1xf32> - %3 = mulf %2, %2 : f32 + %3 = arith.mulf %2, %2 : f32 affine.store %3, %arg1[%arg2] : memref<10xf32> %4 = affine.load %0[0] : memref<1xf32> - %5 = addf %4, %4 : f32 + %5 = arith.addf %4, %4 : f32 affine.store %5, %arg0[%arg2] : memref<10xf32> } return @@ -87,7 +87,7 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> - %2 = mulf %0, %1 : f32 + %2 = arith.mulf %0, %1 : f32 affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> } } @@ -95,7 +95,7 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> - %2 = addf %0, %1 : f32 + %2 = arith.addf %0, %1 : f32 affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32> } } @@ -111,11 +111,11 @@ affine.for %arg6 = 0 to 3 { %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> - %2 = mulf 
%0, %1 : f32 + %2 = arith.mulf %0, %1 : f32 affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> - %5 = addf %3, %4 : f32 + %5 = arith.addf %3, %4 : f32 affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32> } } @@ -481,6 +481,7 @@ let summary = "Coalesce nested loops with independent bounds into a single " "loop"; let constructor = "mlir::createLoopCoalescingPass()"; + let dependentDialects = ["arith::ArithmeticDialect"]; } def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> { @@ -524,7 +525,7 @@ %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) { affine.for %arg3 = 0 to 16 { %a = affine.load %A[%arg3] : memref<16xf64, #tile> - %p = mulf %a, %a : f64 + %p = arith.mulf %a, %a : f64 affine.store %p, %A[%arg3] : memref<16xf64, #tile> } %c = alloc() : memref<16xf64, #tile> @@ -540,7 +541,7 @@ -> memref<4x4xf64> { affine.for %arg3 = 0 to 16 { %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> - %4 = mulf %3, %3 : f64 + %4 = arith.mulf %3, %3 : f64 affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64> } %0 = alloc() : memref<4x4xf64> @@ -566,8 +567,8 @@ %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8> %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8> %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> - %3 = muli %0, %1 : i32 - %4 = addi %2, %3 : i32 + %3 = arith.muli %0, %1 : i32 + %4 = arith.addi %2, %3 : i32 affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8> } } @@ -590,8 +591,8 @@ %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32> %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32> %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32> - %3 = muli %0, %1 : i32 - %4 = addi %2, %3 : i32 + %3 = arith.muli %0, %1 : i32 + %4 = arith.addi %2, %3 : i32 affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32> } } diff 
--git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BuiltinOps.h" @@ -53,10 +54,10 @@ Operation *combinerOp = combinerOps.back(); Optional maybeKind = TypeSwitch>(combinerOp) - .Case([](Operation *) { return AtomicRMWKind::addf; }) - .Case([](Operation *) { return AtomicRMWKind::mulf; }) - .Case([](Operation *) { return AtomicRMWKind::addi; }) - .Case([](Operation *) { return AtomicRMWKind::muli; }) + .Case([](Operation *) { return AtomicRMWKind::addf; }) + .Case([](Operation *) { return AtomicRMWKind::mulf; }) + .Case([](Operation *) { return AtomicRMWKind::addi; }) + .Case([](Operation *) { return AtomicRMWKind::muli; }) .Default([](Operation *) -> Optional { // TODO: AtomicRMW supports other kinds of reductions this is // currently not detecting, add those when the need arises. @@ -632,10 +633,9 @@ auto symbol = operands[i]; assert(isValidSymbol(symbol)); // Check if the symbol is a constant. 
- if (auto cOp = symbol.getDefiningOp()) + if (auto cOp = symbol.getDefiningOp()) dependenceDomain->addBound(FlatAffineConstraints::EQ, - valuePosMap.getSymPos(symbol), - cOp.getValue()); + valuePosMap.getSymPos(symbol), cOp.value()); } }; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IntegerSet.h" @@ -654,8 +655,8 @@ // Add top level symbol. appendSymbolId(val); // Check if the symbol is a constant. - if (auto constOp = val.getDefiningOp()) - addBound(BoundType::EQ, val, constOp.getValue()); + if (auto constOp = val.getDefiningOp()) + addBound(BoundType::EQ, val, constOp.value()); } LogicalResult diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -38,6 +38,7 @@ LINK_LIBS PUBLIC MLIRAffine + MLIRArithmetic MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/PresburgerSet.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/IntegerSet.h" #include "llvm/ADT/SmallPtrSet.h" @@ -98,8 +99,8 @@ assert(cst->containsId(value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. 
- if (auto cOp = value.getDefiningOp()) - cst->addBound(FlatAffineConstraints::EQ, value, cOp.getValue()); + if (auto cOp = value.getDefiningOp()) + cst->addBound(FlatAffineConstraints::EQ, value, cOp.value()); } else if (auto loop = getForInductionVarOwner(value)) { if (failed(cst->addAffineForOpDomain(loop))) return failure(); @@ -517,8 +518,8 @@ assert(isValidSymbol(symbol)); // Check if the symbol is a constant. if (auto *op = symbol.getDefiningOp()) { - if (auto constOp = dyn_cast(op)) { - cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.getValue()); + if (auto constOp = dyn_cast(op)) { + cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.value()); } } } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -15,6 +15,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -56,11 +57,11 @@ } Value visitAddExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr(expr); + return buildBinaryExpr(expr); } Value visitMulExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr(expr); + return buildBinaryExpr(expr); } /// Euclidean modulo operation: negative RHS is not allowed. 
@@ -89,11 +90,12 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value remainder = builder.create(loc, lhs, rhs); - Value zeroCst = builder.create(loc, 0); - Value isRemainderNegative = - builder.create(loc, CmpIPredicate::slt, remainder, zeroCst); - Value correctedRemainder = builder.create(loc, remainder, rhs); + Value remainder = builder.create(loc, lhs, rhs); + Value zeroCst = builder.create(loc, 0); + Value isRemainderNegative = builder.create( + loc, arith::CmpIPredicate::slt, remainder, zeroCst); + Value correctedRemainder = + builder.create(loc, remainder, rhs); Value result = builder.create(loc, isRemainderNegative, correctedRemainder, remainder); return result; @@ -126,15 +128,16 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value noneCst = builder.create(loc, -1); - Value negative = - builder.create(loc, CmpIPredicate::slt, lhs, zeroCst); - Value negatedDecremented = builder.create(loc, noneCst, lhs); + Value zeroCst = builder.create(loc, 0); + Value noneCst = builder.create(loc, -1); + Value negative = builder.create( + loc, arith::CmpIPredicate::slt, lhs, zeroCst); + Value negatedDecremented = builder.create(loc, noneCst, lhs); Value dividend = builder.create(loc, negative, negatedDecremented, lhs); - Value quotient = builder.create(loc, dividend, rhs); - Value correctedQuotient = builder.create(loc, noneCst, quotient); + Value quotient = builder.create(loc, dividend, rhs); + Value correctedQuotient = + builder.create(loc, noneCst, quotient); Value result = builder.create(loc, negative, correctedQuotient, quotient); return result; @@ -165,27 +168,26 @@ auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value oneCst = builder.create(loc, 1); - Value nonPositive = - builder.create(loc, CmpIPredicate::sle, lhs, 
zeroCst); - Value negated = builder.create(loc, zeroCst, lhs); - Value decremented = builder.create(loc, lhs, oneCst); + Value zeroCst = builder.create(loc, 0); + Value oneCst = builder.create(loc, 1); + Value nonPositive = builder.create( + loc, arith::CmpIPredicate::sle, lhs, zeroCst); + Value negated = builder.create(loc, zeroCst, lhs); + Value decremented = builder.create(loc, lhs, oneCst); Value dividend = builder.create(loc, nonPositive, negated, decremented); - Value quotient = builder.create(loc, dividend, rhs); - Value negatedQuotient = builder.create(loc, zeroCst, quotient); - Value incrementedQuotient = builder.create(loc, quotient, oneCst); + Value quotient = builder.create(loc, dividend, rhs); + Value negatedQuotient = + builder.create(loc, zeroCst, quotient); + Value incrementedQuotient = + builder.create(loc, quotient, oneCst); Value result = builder.create(loc, nonPositive, negatedQuotient, incrementedQuotient); return result; } Value visitConstantExpr(AffineConstantExpr expr) { - auto valueAttr = - builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); - auto op = - builder.create(loc, builder.getIndexType(), valueAttr); + auto op = builder.create(loc, expr.getValue()); return op.getResult(); } @@ -242,20 +244,21 @@ /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the /// `cmpi` operation followed by the `select` operation: /// -/// %cond = cmpi "predicate" %v0, %v1 +/// %cond = arith.cmpi "predicate" %v0, %v1 /// %result = select %cond, %v0, %v1 /// /// Multiple values are scanned in a linear sequence. This creates a data /// dependences that wouldn't exist in a tree reduction, but is easier to /// recognize as a reduction by the subsequent passes. 
-static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, +static Value buildMinMaxReductionSeq(Location loc, + arith::CmpIPredicate predicate, ValueRange values, OpBuilder &builder) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { - auto cmpOp = builder.create(loc, predicate, value, *valueIt); + auto cmpOp = builder.create(loc, predicate, value, *valueIt); value = builder.create(loc, cmpOp.getResult(), value, *valueIt); } @@ -267,7 +270,8 @@ static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands) { if (auto values = expandAffineMap(builder, loc, map, operands)) - return buildMinMaxReductionSeq(loc, CmpIPredicate::sgt, *values, builder); + return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values, + builder); return nullptr; } @@ -276,7 +280,8 @@ static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands) { if (auto values = expandAffineMap(builder, loc, map, operands)) - return buildMinMaxReductionSeq(loc, CmpIPredicate::slt, *values, builder); + return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values, + builder); return nullptr; } @@ -356,7 +361,7 @@ Location loc = op.getLoc(); Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); - Value step = rewriter.create(loc, op.getStep()); + Value step = rewriter.create(loc, op.getStep()); auto scfForOp = rewriter.create(loc, lowerBound, upperBound, step, op.getIterOperands()); rewriter.eraseBlock(scfForOp.getBody()); @@ -399,7 +404,7 @@ } steps.reserve(op.steps().size()); for (Attribute step : op.steps()) - steps.push_back(rewriter.create( + steps.push_back(rewriter.create( loc, step.cast().getInt())); // Get the terminator op. @@ -475,7 +480,7 @@ // Now we just have to handle the condition logic. 
auto integerSet = op.getIntegerSet(); - Value zeroConstant = rewriter.create(loc, 0); + Value zeroConstant = rewriter.create(loc, 0); SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); @@ -492,14 +497,17 @@ operandsRef.drop_front(numDims)); if (!affResult) return failure(); - auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; + auto pred = + isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge; Value cmpVal = - rewriter.create(loc, pred, affResult, zeroConstant); - cond = - cond ? rewriter.create(loc, cond, cmpVal).getResult() : cmpVal; + rewriter.create(loc, pred, affResult, zeroConstant); + cond = cond + ? rewriter.create(loc, cond, cmpVal).getResult() + : cmpVal; } cond = cond ? cond - : rewriter.create(loc, /*value=*/1, /*width=*/1); + : rewriter.create(loc, /*value=*/1, + /*width=*/1); bool hasElseRegion = !op.elseRegion().empty(); auto ifOp = rewriter.create(loc, op.getResultTypes(), cond, @@ -750,8 +758,9 @@ populateAffineToStdConversionPatterns(patterns); populateAffineToVectorConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); + target + .addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRAffine + MLIRArithmetic MLIRMemRef MLIRSCF MLIRPass diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp @@ -0,0 +1,256 @@ +//===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// +// +// Part of the LLVM Project, under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "../PassDetail.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Operation Conversion +//===----------------------------------------------------------------------===// + +namespace { + +/// Straightforward lowerings. +using AddIOpLowering = VectorConvertToLLVMPattern; +using SubIOpLowering = VectorConvertToLLVMPattern; +using MulIOpLowering = VectorConvertToLLVMPattern; +using DivUIOpLowering = + VectorConvertToLLVMPattern; +using DivSIOpLowering = + VectorConvertToLLVMPattern; +using RemUIOpLowering = + VectorConvertToLLVMPattern; +using RemSIOpLowering = + VectorConvertToLLVMPattern; +using AndIOpLowering = VectorConvertToLLVMPattern; +using OrIOpLowering = VectorConvertToLLVMPattern; +using XOrIOpLowering = VectorConvertToLLVMPattern; +using ShLIOpLowering = VectorConvertToLLVMPattern; +using ShRUIOpLowering = + VectorConvertToLLVMPattern; +using ShRSIOpLowering = + VectorConvertToLLVMPattern; +using NegFOpLowering = VectorConvertToLLVMPattern; +using AddFOpLowering = VectorConvertToLLVMPattern; +using SubFOpLowering = VectorConvertToLLVMPattern; +using MulFOpLowering = VectorConvertToLLVMPattern; +using DivFOpLowering = VectorConvertToLLVMPattern; +using RemFOpLowering = VectorConvertToLLVMPattern; +using ExtUIOpLowering = + VectorConvertToLLVMPattern; +using ExtSIOpLowering = + VectorConvertToLLVMPattern; +using ExtFOpLowering = 
VectorConvertToLLVMPattern; +using TruncIOpLowering = + VectorConvertToLLVMPattern; +using TruncFOpLowering = + VectorConvertToLLVMPattern; +using UIToFPOpLowering = + VectorConvertToLLVMPattern; +using SIToFPOpLowering = + VectorConvertToLLVMPattern; +using FPToUIOpLowering = + VectorConvertToLLVMPattern; +using FPToSIOpLowering = + VectorConvertToLLVMPattern; +using BitcastOpLowering = + VectorConvertToLLVMPattern; + +/// Directly lower to LLVM op. +struct ConstantOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return LLVM::detail::oneToOneRewrite( + op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), + *getTypeConverter(), rewriter); + } +}; + +/// The lowering of index_cast becomes an integer conversion since index +/// becomes an integer. If the bit width of the source and target integer +/// types is the same, just erase the cast. If the target type is wider, +/// sign-extend the value, otherwise truncate it. 
+struct IndexCastOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto targetType = typeConverter->convertType(op.getResult().getType()); + auto targetElementType = + typeConverter->convertType(getElementTypeOrSelf(op.getResult())) + .cast(); + auto sourceElementType = + getElementTypeOrSelf(adaptor.in()).cast(); + unsigned targetBits = targetElementType.getWidth(); + unsigned sourceBits = sourceElementType.getWidth(); + + if (targetBits == sourceBits) + rewriter.replaceOp(op, adaptor.in()); + else if (targetBits < sourceBits) + rewriter.replaceOpWithNewOp(op, targetType, adaptor.in()); + else + rewriter.replaceOpWithNewOp(op, targetType, adaptor.in()); + return success(); + } +}; + +// Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums +// share numerical values so just cast. +template +static LLVMPredType convertCmpPredicate(PredType pred) { + return static_cast(pred); +} + +struct CmpIOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto operandType = adaptor.lhs().getType(); + auto resultType = op.getResult().getType(); + + // Handle the scalar and 1D vector cases. 
+ if (!operandType.isa()) { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(resultType), + convertCmpPredicate(op.getPredicate()), + adaptor.lhs(), adaptor.rhs()); + return success(); + } + + auto vectorType = resultType.dyn_cast(); + if (!vectorType) + return rewriter.notifyMatchFailure(op, "expected vector result type"); + + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) { + OpAdaptor adaptor(operands); + return rewriter.create( + op.getLoc(), llvm1DVectorTy, + convertCmpPredicate(op.getPredicate()), + adaptor.lhs(), adaptor.rhs()); + }, + rewriter); + + return success(); + } +}; + +struct CmpFOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto operandType = adaptor.lhs().getType(); + auto resultType = op.getResult().getType(); + + // Handle the scalar and 1D vector cases. 
+ if (!operandType.isa()) { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(resultType), + convertCmpPredicate(op.getPredicate()), + adaptor.lhs(), adaptor.rhs()); + return success(); + } + + auto vectorType = resultType.dyn_cast(); + if (!vectorType) + return rewriter.notifyMatchFailure(op, "expected vector result type"); + + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) { + OpAdaptor adaptor(operands); + return rewriter.create( + op.getLoc(), llvm1DVectorTy, + convertCmpPredicate(op.getPredicate()), + adaptor.lhs(), adaptor.rhs()); + }, + rewriter); + } +}; + +struct ConvertArithmeticToLLVMPass + : public ConvertArithmeticToLLVMBase { + ConvertArithmeticToLLVMPass() = default; + + void runOnFunction() override { + RewritePatternSet patterns(&getContext()); + LLVMTypeConverter converter(&getContext()); + populateArithmeticToLLVMConversionPatterns(converter, patterns); + LLVMConversionTarget target(getContext()); + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +void mlir::populateArithmeticToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // clang-format off + patterns.add< + ConstantOpLowering, + AddIOpLowering, + SubIOpLowering, + MulIOpLowering, + DivUIOpLowering, + DivSIOpLowering, + RemUIOpLowering, + RemSIOpLowering, + AndIOpLowering, + OrIOpLowering, + XOrIOpLowering, + ShLIOpLowering, + ShRUIOpLowering, + ShRSIOpLowering, + NegFOpLowering, + AddFOpLowering, + SubFOpLowering, + MulFOpLowering, + DivFOpLowering, + RemFOpLowering, + ExtUIOpLowering, + ExtSIOpLowering, + ExtFOpLowering, + 
TruncIOpLowering, + TruncFOpLowering, + UIToFPOpLowering, + SIToFPOpLowering, + FPToUIOpLowering, + FPToSIOpLowering, + IndexCastOpLowering, + BitcastOpLowering, + CmpIOpLowering, + CmpFOpLowering + >(converter); + // clang-format on +} + +std::unique_ptr mlir::createConvertArithmeticToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToLLVM/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRArithmeticToLLVM + ArithmeticToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithmeticToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRLLVMCommonConversion + MLIRLLVMIR + ) diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -0,0 +1,796 @@ +//===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/SPIRVCommon/Pattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "arith-to-spirv-pattern" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Operation Conversion +//===----------------------------------------------------------------------===// + +namespace { + +/// Converts composite arith.const operation to spv.Constant. +struct ConstantCompositeOpPattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts scalar arith.const operation to spv.Constant. +struct ConstantScalarOpPattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.remsi to SPIR-V ops. +/// +/// This cannot be merged into the template unary/binary pattern due to Vulkan +/// restrictions over spv.SRem and spv.SMod. +struct RemSIOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts bitwise operations to SPIR-V operations. 
This is a special pattern +/// other than the BinaryOpPatternPattern because if the operands are boolean +/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For +/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. +template +struct BitwiseOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.xori to SPIR-V operations. +struct XOrIOpLogicalPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.xori to SPIR-V operations if the type of source is i1 or +/// vector of i1. +struct XOrIOpBooleanPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of +/// i1. +struct UIToFPI1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.extui to spv.Select if the type of source is i1 or vector of +/// i1. +struct ExtUII1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.trunci to spv.Select if the type of result is i1 or vector of +/// i1. 
+struct TruncII1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts type-casting standard operations to SPIR-V operations. +template +struct TypeCastingOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts integer compare operation on i1 type operands to SPIR-V ops. +class CmpIOpBooleanPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts integer compare operation to SPIR-V ops. +class CmpIOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts floating-point comparison operations to SPIR-V ops. +class CmpFOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts floating point NaN check to SPIR-V ops. This pattern requires +/// Kernel capability. +class CmpFOpNanKernelPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts floating point NaN check to SPIR-V ops. 
This pattern does not +/// require additional capability. +class CmpFOpNanNonePattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Conversion Helpers +//===----------------------------------------------------------------------===// + +/// Converts the given `srcAttr` into a boolean attribute if it holds an +/// integral value. Returns null attribute if conversion fails. +static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { + if (auto boolAttr = srcAttr.dyn_cast()) + return boolAttr; + if (auto intAttr = srcAttr.dyn_cast()) + return builder.getBoolAttr(intAttr.getValue().getBoolValue()); + return BoolAttr(); +} + +/// Converts the given `srcAttr` to a new attribute of the given `dstType`. +/// Returns null attribute if conversion fails. +static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, + Builder builder) { + // If the source number uses fewer active bits than the target bitwidth, then + // it should be safe to convert. + if (srcAttr.getValue().isIntN(dstType.getWidth())) + return builder.getIntegerAttr(dstType, srcAttr.getInt()); + + // XXX: Try again by interpreting the source number as a signed value. + // Although integers in the standard dialect are signless, they can represent + // a signed number. It's the operation that decides how to interpret it. This is + // dangerous, but it seems there is no good way of handling this if we still + // want to change the bitwidth. Emit a message at least.
+ if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { + auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); + LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" + << dstAttr << "' for type '" << dstType << "'\n"); + return dstAttr; + } + + LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr + << "' illegal: cannot fit into target type '" + << dstType << "'\n"); + return IntegerAttr(); +} + +/// Converts the given `srcAttr` to a new attribute of the given `dstType`. +/// Returns null attribute if `dstType` is not 32-bit or conversion fails. +static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, + Builder builder) { + // Only support converting to float for now. + if (!dstType.isF32()) + return FloatAttr(); + + // Try to convert the source floating-point number to single precision. + APFloat dstVal = srcAttr.getValue(); + bool losesInfo = false; + APFloat::opStatus status = + dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); + if (status != APFloat::opOK || losesInfo) { + LLVM_DEBUG(llvm::dbgs() + << srcAttr << " illegal: cannot fit into converted type '" + << dstType << "'\n"); + return FloatAttr(); + } + + return builder.getF32FloatAttr(dstVal.convertToFloat()); +} + +/// Returns true if the given `type` is a boolean scalar or vector type. 
+static bool isBoolScalarOrVector(Type type) { + if (type.isInteger(1)) + return true; + if (auto vecType = type.dyn_cast()) + return vecType.getElementType().isInteger(1); + return false; +} + +//===----------------------------------------------------------------------===// +// ConstantOp with composite type +//===----------------------------------------------------------------------===// + +LogicalResult ConstantCompositeOpPattern::matchAndRewrite( + arith::ConstantOp constOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto srcType = constOp.getType().dyn_cast(); + if (!srcType) + return failure(); + + // std.constant should only have vector or tensor types. + assert((srcType.isa())); + + auto dstType = getTypeConverter()->convertType(srcType); + if (!dstType) + return failure(); + + auto dstElementsAttr = constOp.value().dyn_cast(); + ShapedType dstAttrType = dstElementsAttr.getType(); + if (!dstElementsAttr) + return failure(); + + // If the composite type has more than one dimension, perform linearization. + if (srcType.getRank() > 1) { + if (srcType.isa()) { + dstAttrType = RankedTensorType::get(srcType.getNumElements(), + srcType.getElementType()); + dstElementsAttr = dstElementsAttr.reshape(dstAttrType); + } else { + // TODO: add support for large vectors. + return failure(); + } + } + + Type srcElemType = srcType.getElementType(); + Type dstElemType; + // Tensor types are converted to SPIR-V array types; vector types are + // converted to SPIR-V vector/array types. + if (auto arrayType = dstType.dyn_cast()) + dstElemType = arrayType.getElementType(); + else + dstElemType = dstType.cast().getElementType(); + + // If the source and destination element types are different, perform + // attribute conversion.
+ if (srcElemType != dstElemType) { + SmallVector elements; + if (srcElemType.isa()) { + for (FloatAttr srcAttr : dstElementsAttr.getValues()) { + FloatAttr dstAttr = + convertFloatAttr(srcAttr, dstElemType.cast(), rewriter); + if (!dstAttr) + return failure(); + elements.push_back(dstAttr); + } + } else if (srcElemType.isInteger(1)) { + return failure(); + } else { + for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { + IntegerAttr dstAttr = convertIntegerAttr( + srcAttr, dstElemType.cast(), rewriter); + if (!dstAttr) + return failure(); + elements.push_back(dstAttr); + } + } + + // Unfortunately, we cannot use dialect-specific types for element + // attributes; element attributes only works with builtin types. So we need + // to prepare another converted builtin types for the destination elements + // attribute. + if (dstAttrType.isa()) + dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); + else + dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); + + dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); + } + + rewriter.replaceOpWithNewOp(constOp, dstType, + dstElementsAttr); + return success(); +} + +//===----------------------------------------------------------------------===// +// ConstantOp with scalar type +//===----------------------------------------------------------------------===// + +LogicalResult ConstantScalarOpPattern::matchAndRewrite( + arith::ConstantOp constOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Type srcType = constOp.getType(); + if (!srcType.isIntOrIndexOrFloat()) + return failure(); + + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) + return failure(); + + // Floating-point types. + if (srcType.isa()) { + auto srcAttr = constOp.value().cast(); + auto dstAttr = srcAttr; + + // Floating-point types not supported in the target environment are all + // converted to float type. 
+ if (srcType != dstType) { + dstAttr = convertFloatAttr(srcAttr, dstType.cast(), rewriter); + if (!dstAttr) + return failure(); + } + + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); + } + + // Bool type. + if (srcType.isInteger(1)) { + // std.constant can use 0/1 instead of true/false for i1 values. We need to + // handle that here. + auto dstAttr = convertBoolAttr(constOp.value(), rewriter); + if (!dstAttr) + return failure(); + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); + } + + // IndexType or IntegerType. Index values are converted to 32-bit integer + // values when converting to SPIR-V. + auto srcAttr = constOp.value().cast(); + auto dstAttr = + convertIntegerAttr(srcAttr, dstType.cast(), rewriter); + if (!dstAttr) + return failure(); + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); +} + +//===----------------------------------------------------------------------===// +// RemSIOpPattern +//===----------------------------------------------------------------------===// + +/// Returns signed remainder for `lhs` and `rhs` and lets the result follow +/// the sign of `signOperand`. +/// +/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment +/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative +/// the result is undefined." So we cannot directly use spv.SRem/spv.SMod +/// if either operand can be negative. Emulate it via spv.UMod. +static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, + Value signOperand, OpBuilder &builder) { + assert(lhs.getType() == rhs.getType()); + assert(lhs == signOperand || rhs == signOperand); + + Type type = lhs.getType(); + + // Calculate the remainder with spv.UMod. + Value lhsAbs = builder.create(loc, type, lhs); + Value rhsAbs = builder.create(loc, type, rhs); + Value abs = builder.create(loc, lhsAbs, rhsAbs); + + // Fix the sign. 
+ Value isPositive; + if (lhs == signOperand) + isPositive = builder.create(loc, lhs, lhsAbs); + else + isPositive = builder.create(loc, rhs, rhsAbs); + Value absNegate = builder.create(loc, type, abs); + return builder.create(loc, type, isPositive, abs, absNegate); +} + +LogicalResult +RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0], + adaptor.getOperands()[1], + adaptor.getOperands()[0], rewriter); + rewriter.replaceOp(op, result); + + return success(); +} + +//===----------------------------------------------------------------------===// +// BitwiseOpPattern +//===----------------------------------------------------------------------===// + +template +LogicalResult +BitwiseOpPattern::matchAndRewrite( + Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(adaptor.getOperands().size() == 2); + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!dstType) + return failure(); + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + } else { + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// XOrIOpLogicalPattern +//===----------------------------------------------------------------------===// + +LogicalResult XOrIOpLogicalPattern::matchAndRewrite( + arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(adaptor.getOperands().size() == 2); + + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); + + auto dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp(op, dstType, + 
adaptor.getOperands()); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XOrIOpBooleanPattern +//===----------------------------------------------------------------------===// + +LogicalResult XOrIOpBooleanPattern::matchAndRewrite( + arith::XOrIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(adaptor.getOperands().size() == 2); + + if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) + return failure(); + + auto dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + return success(); +} + +//===----------------------------------------------------------------------===// +// UIToFPI1Pattern +//===----------------------------------------------------------------------===// + +LogicalResult +UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto srcType = adaptor.getOperands().front().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); + + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.template replaceOpWithNewOp( + op, dstType, adaptor.getOperands().front(), one, zero); + return success(); +} + +//===----------------------------------------------------------------------===// +// ExtUII1Pattern +//===----------------------------------------------------------------------===// + +LogicalResult +ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto srcType = adaptor.getOperands().front().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); + + auto dstType = + 
this->getTypeConverter()->convertType(op.getResult().getType()); + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.template replaceOpWithNewOp( + op, dstType, adaptor.getOperands().front(), one, zero); + return success(); +} + +//===----------------------------------------------------------------------===// +// TruncII1Pattern +//===----------------------------------------------------------------------===// + +LogicalResult +TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!isBoolScalarOrVector(dstType)) + return failure(); + + Location loc = op.getLoc(); + auto srcType = adaptor.getOperands().front().getType(); + // Check if (x & 1) == 1. + Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); + Value maskedSrc = rewriter.create( + loc, srcType, adaptor.getOperands()[0], mask); + Value isOne = rewriter.create(loc, maskedSrc, mask); + + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); + return success(); +} + +//===----------------------------------------------------------------------===// +// TypeCastingOpPattern +//===----------------------------------------------------------------------===// + +template +LogicalResult TypeCastingOpPattern::matchAndRewrite( + Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(adaptor.getOperands().size() == 1); + auto srcType = adaptor.getOperands().front().getType(); + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) + return failure(); + if (dstType == 
srcType) { + // Due to type conversion, we are seeing the same source and target type. + // Then we can just erase this operation by forwarding its operand. + rewriter.replaceOp(op, adaptor.getOperands().front()); + } else { + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// CmpIOpBooleanPattern +//===----------------------------------------------------------------------===// + +LogicalResult CmpIOpBooleanPattern::matchAndRewrite( + arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Type operandType = op.lhs().getType(); + if (!isBoolScalarOrVector(operandType)) + return failure(); + + switch (op.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ + adaptor.lhs(), adaptor.rhs()); \ + return success(); + + DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); + DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); + +#undef DISPATCH + default:; + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// CmpIOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult +CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Type operandType = op.lhs().getType(); + if (isBoolScalarOrVector(operandType)) + return failure(); + + switch (op.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + if (spirvOp::template hasTrait() && \ + operandType != this->getTypeConverter()->convertType(operandType)) { \ + return op.emitError( \ + "bitwidth emulation is not implemented yet on unsigned op"); \ + } \ + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ + adaptor.lhs(), adaptor.rhs()); \ + return 
success(); + + DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); + DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); + DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); + DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); + DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); + DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); + DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); + DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); + DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); + DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); + +#undef DISPATCH + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// CmpFOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult +CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + switch (op.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ + adaptor.lhs(), adaptor.rhs()); \ + return success(); + + // Ordered. + DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); + DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); + DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); + DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); + DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); + DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); + // Unordered. 
+ DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); + DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); + DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); + DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); + DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); + DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); + +#undef DISPATCH + + default: + break; + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// CmpFOpNanKernelPattern +//===----------------------------------------------------------------------===// + +LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( + arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (op.getPredicate() == arith::CmpFPredicate::ORD) { + rewriter.replaceOpWithNewOp(op, adaptor.lhs(), + adaptor.rhs()); + return success(); + } + + if (op.getPredicate() == arith::CmpFPredicate::UNO) { + rewriter.replaceOpWithNewOp(op, adaptor.lhs(), + adaptor.rhs()); + return success(); + } + + return failure(); +} + +//===----------------------------------------------------------------------===// +// CmpFOpNanNonePattern +//===----------------------------------------------------------------------===// + +LogicalResult CmpFOpNanNonePattern::matchAndRewrite( + arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (op.getPredicate() != arith::CmpFPredicate::ORD && + op.getPredicate() != arith::CmpFPredicate::UNO) + return failure(); + + Location loc = op.getLoc(); + + Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); + Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); + + Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); + if (op.getPredicate() == arith::CmpFPredicate::ORD) + replace = rewriter.create(loc, replace); + + rewriter.replaceOp(op, replace); + return success(); +} + 
+//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +namespace mlir { +void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { + // clang-format off + patterns.add< + ConstantCompositeOpPattern, + ConstantScalarOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + RemSIOpPattern, + BitwiseOpPattern, + BitwiseOpPattern, + XOrIOpLogicalPattern, XOrIOpBooleanPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + TypeCastingOpPattern, ExtUII1Pattern, + TypeCastingOpPattern, + TypeCastingOpPattern, + TypeCastingOpPattern, TruncII1Pattern, + TypeCastingOpPattern, + TypeCastingOpPattern, UIToFPI1Pattern, + TypeCastingOpPattern, + TypeCastingOpPattern, + TypeCastingOpPattern, + CmpIOpBooleanPattern, CmpIOpPattern, + CmpFOpNanNonePattern, CmpFOpPattern + >(typeConverter, patterns.getContext()); + // clang-format on + + // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel + // capability is available. 
+ patterns.add(typeConverter, patterns.getContext(), + /*benefit=*/2); +} +} // end namespace mlir diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRArithmeticToSPIRV + ArithmeticToSPIRV.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithmeticToSPIRV + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRSPIRVConversion + MLIRSPIRV + ) diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -903,9 +904,9 @@ LogicalResult matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto count = - rewriter.create(op->getLoc(), rewriter.getI64Type(), - rewriter.getI64IntegerAttr(op.count())); + auto count = rewriter.create( + op->getLoc(), rewriter.getI64Type(), + rewriter.getI64IntegerAttr(op.count())); auto operand = adaptor.operand(); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, @@ -1008,7 +1009,8 @@ converter, ctx); ConversionTarget target(*ctx); - target.addLegalOp(); + target + .addLegalOp(); target.addLegalDialect(); // All operations from Async dialect must be lowered to the runtime API and diff --git a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt 
b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRAsync MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,4 +1,6 @@ add_subdirectory(AffineToStandard) +add_subdirectory(ArithmeticToLLVM) +add_subdirectory(ArithmeticToSPIRV) add_subdirectory(ArmNeon2dToIntr) add_subdirectory(AsyncToLLVM) add_subdirectory(ComplexToLLVM) diff --git a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt @@ -8,6 +8,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRComplex MLIRIR MLIRMath diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -12,6 +12,7 @@ #include #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -33,21 +34,21 @@ Value real = rewriter.create(loc, type, adaptor.complex()); Value imag = rewriter.create(loc, type, adaptor.complex()); - Value realSqr = rewriter.create(loc, real, real); - Value imagSqr = rewriter.create(loc, imag, imag); - Value sqNorm = rewriter.create(loc, realSqr, imagSqr); + Value realSqr = rewriter.create(loc, real, real); + Value imagSqr = rewriter.create(loc, imag, imag); + Value sqNorm = rewriter.create(loc, realSqr, imagSqr); rewriter.replaceOpWithNewOp(op, sqNorm); return 
success(); } }; -template +template struct ComparisonOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using ResultCombiner = std::conditional_t::value, - AndOp, OrOp>; + arith::AndIOp, arith::OrIOp>; LogicalResult matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, @@ -60,8 +61,10 @@ Value imagLhs = rewriter.create(loc, type, adaptor.lhs()); Value realRhs = rewriter.create(loc, type, adaptor.rhs()); Value imagRhs = rewriter.create(loc, type, adaptor.rhs()); - Value realComparison = rewriter.create(loc, p, realLhs, realRhs); - Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); + Value realComparison = + rewriter.create(loc, p, realLhs, realRhs); + Value imagComparison = + rewriter.create(loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, realComparison, imagComparison); @@ -138,139 +141,150 @@ // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom // // See https://dl.acm.org/citation.cfm?id=368661 for more details. 
- Value rhsRealImagRatio = rewriter.create(loc, rhsReal, rhsImag); - Value rhsRealImagDenom = rewriter.create( - loc, rhsImag, rewriter.create(loc, rhsRealImagRatio, rhsReal)); - Value realNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), lhsImag); + Value rhsRealImagRatio = + rewriter.create(loc, rhsReal, rhsImag); + Value rhsRealImagDenom = rewriter.create( + loc, rhsImag, + rewriter.create(loc, rhsRealImagRatio, rhsReal)); + Value realNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), + lhsImag); Value resultReal1 = - rewriter.create(loc, realNumerator1, rhsRealImagDenom); - Value imagNumerator1 = rewriter.create( - loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), lhsReal); + rewriter.create(loc, realNumerator1, rhsRealImagDenom); + Value imagNumerator1 = rewriter.create( + loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), + lhsReal); Value resultImag1 = - rewriter.create(loc, imagNumerator1, rhsRealImagDenom); - - Value rhsImagRealRatio = rewriter.create(loc, rhsImag, rhsReal); - Value rhsImagRealDenom = rewriter.create( - loc, rhsReal, rewriter.create(loc, rhsImagRealRatio, rhsImag)); - Value realNumerator2 = rewriter.create( - loc, lhsReal, rewriter.create(loc, lhsImag, rhsImagRealRatio)); + rewriter.create(loc, imagNumerator1, rhsRealImagDenom); + + Value rhsImagRealRatio = + rewriter.create(loc, rhsImag, rhsReal); + Value rhsImagRealDenom = rewriter.create( + loc, rhsReal, + rewriter.create(loc, rhsImagRealRatio, rhsImag)); + Value realNumerator2 = rewriter.create( + loc, lhsReal, + rewriter.create(loc, lhsImag, rhsImagRealRatio)); Value resultReal2 = - rewriter.create(loc, realNumerator2, rhsImagRealDenom); - Value imagNumerator2 = rewriter.create( - loc, lhsImag, rewriter.create(loc, lhsReal, rhsImagRealRatio)); + rewriter.create(loc, realNumerator2, rhsImagRealDenom); + Value imagNumerator2 = rewriter.create( + loc, lhsImag, + rewriter.create(loc, lhsReal, rhsImagRealRatio)); 
Value resultImag2 = - rewriter.create(loc, imagNumerator2, rhsImagRealDenom); + rewriter.create(loc, imagNumerator2, rhsImagRealDenom); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. - Value zero = rewriter.create(loc, elementType, - rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create(loc, rhsReal); - Value rhsRealIsZero = - rewriter.create(loc, CmpFPredicate::OEQ, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create(loc, rhsImag); - Value rhsImagIsZero = - rewriter.create(loc, CmpFPredicate::OEQ, rhsImagAbs, zero); - Value lhsRealIsNotNaN = - rewriter.create(loc, CmpFPredicate::ORD, lhsReal, zero); - Value lhsImagIsNotNaN = - rewriter.create(loc, CmpFPredicate::ORD, lhsImag, zero); + Value zero = rewriter.create( + loc, elementType, rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = rewriter.create(loc, rhsReal); + Value rhsRealIsZero = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); + Value rhsImagAbs = rewriter.create(loc, rhsImag); + Value rhsImagIsZero = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); + Value lhsRealIsNotNaN = rewriter.create( + loc, arith::CmpFPredicate::ORD, lhsReal, zero); + Value lhsImagIsNotNaN = rewriter.create( + loc, arith::CmpFPredicate::ORD, lhsImag, zero); Value lhsContainsNotNaNValue = - rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create( + rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = rewriter.create( loc, lhsContainsNotNaNValue, - rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create( + rewriter.create(loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = rewriter.create( loc, elementType, rewriter.getFloatAttr( elementType, APFloat::getInf(elementType.getFloatSemantics()))); - Value infWithSignOfRhsReal = rewriter.create(loc, inf, rhsReal); + Value infWithSignOfRhsReal = + rewriter.create(loc, inf, 
rhsReal); Value infinityResultReal = - rewriter.create(loc, infWithSignOfRhsReal, lhsReal); + rewriter.create(loc, infWithSignOfRhsReal, lhsReal); Value infinityResultImag = - rewriter.create(loc, infWithSignOfRhsReal, lhsImag); + rewriter.create(loc, infWithSignOfRhsReal, lhsImag); // Case 2. Infinite numerator, finite denominator. - Value rhsRealFinite = - rewriter.create(loc, CmpFPredicate::ONE, rhsRealAbs, inf); - Value rhsImagFinite = - rewriter.create(loc, CmpFPredicate::ONE, rhsImagAbs, inf); - Value rhsFinite = rewriter.create(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create(loc, lhsReal); - Value lhsRealInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create(loc, lhsImag); - Value lhsImagInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, lhsImagAbs, inf); + Value rhsRealFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); + Value rhsImagFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); + Value rhsFinite = + rewriter.create(loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = rewriter.create(loc, lhsReal); + Value lhsRealInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagAbs = rewriter.create(loc, lhsImag); + Value lhsImagInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); + rewriter.create(loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create( + rewriter.create(loc, lhsInfinite, rhsFinite); + Value one = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create( + Value lhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsRealInfinite, one, zero), lhsReal); - Value lhsImagIsInfWithSign = rewriter.create( + 
Value lhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, lhsImagInfinite, one, zero), lhsImag); Value lhsRealIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); + rewriter.create(loc, lhsRealIsInfWithSign, rhsReal); Value lhsImagIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsImagIsInfWithSign, rhsImag); - Value resultReal3 = rewriter.create( + rewriter.create(loc, lhsImagIsInfWithSign, rhsImag); + Value resultReal3 = rewriter.create( loc, inf, - rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, - lhsImagIsInfWithSignTimesRhsImag)); + rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal, + lhsImagIsInfWithSignTimesRhsImag)); Value lhsRealIsInfWithSignTimesRhsImag = - rewriter.create(loc, lhsRealIsInfWithSign, rhsImag); + rewriter.create(loc, lhsRealIsInfWithSign, rhsImag); Value lhsImagIsInfWithSignTimesRhsReal = - rewriter.create(loc, lhsImagIsInfWithSign, rhsReal); - Value resultImag3 = rewriter.create( + rewriter.create(loc, lhsImagIsInfWithSign, rhsReal); + Value resultImag3 = rewriter.create( loc, inf, - rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, - lhsRealIsInfWithSignTimesRhsImag)); + rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal, + lhsRealIsInfWithSignTimesRhsImag)); // Case 3: Finite numerator, infinite denominator. 
- Value lhsRealFinite = - rewriter.create(loc, CmpFPredicate::ONE, lhsRealAbs, inf); - Value lhsImagFinite = - rewriter.create(loc, CmpFPredicate::ONE, lhsImagAbs, inf); - Value lhsFinite = rewriter.create(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagInfinite = - rewriter.create(loc, CmpFPredicate::OEQ, rhsImagAbs, inf); + Value lhsRealFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); + Value lhsImagFinite = rewriter.create( + loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); + Value lhsFinite = + rewriter.create(loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagInfinite = rewriter.create( + loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); + rewriter.create(loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = - rewriter.create(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create( + rewriter.create(loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsRealInfinite, one, zero), rhsReal); - Value rhsImagIsInfWithSign = rewriter.create( + Value rhsImagIsInfWithSign = rewriter.create( loc, rewriter.create(loc, rhsImagInfinite, one, zero), rhsImag); Value rhsRealIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); + rewriter.create(loc, lhsReal, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsImag, rhsImagIsInfWithSign); - Value resultReal4 = rewriter.create( + rewriter.create(loc, lhsImag, rhsImagIsInfWithSign); + Value resultReal4 = rewriter.create( loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, - rhsImagIsInfWithSignTimesLhsImag)); + rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal, + 
rhsImagIsInfWithSignTimesLhsImag)); Value rhsRealIsInfWithSignTimesLhsImag = - rewriter.create(loc, lhsImag, rhsRealIsInfWithSign); + rewriter.create(loc, lhsImag, rhsRealIsInfWithSign); Value rhsImagIsInfWithSignTimesLhsReal = - rewriter.create(loc, lhsReal, rhsImagIsInfWithSign); - Value resultImag4 = rewriter.create( + rewriter.create(loc, lhsReal, rhsImagIsInfWithSign); + Value resultImag4 = rewriter.create( loc, zero, - rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, - rhsImagIsInfWithSignTimesLhsReal)); + rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag, + rhsImagIsInfWithSignTimesLhsReal)); - Value realAbsSmallerThanImagAbs = rewriter.create( - loc, CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); + Value realAbsSmallerThanImagAbs = rewriter.create( + loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); Value resultReal = rewriter.create(loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); Value resultImag = rewriter.create(loc, realAbsSmallerThanImagAbs, @@ -288,12 +302,12 @@ Value resultImagSpecialCase1 = rewriter.create( loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = - rewriter.create(loc, CmpFPredicate::UNO, resultReal, zero); - Value resultImagIsNaN = - rewriter.create(loc, CmpFPredicate::UNO, resultImag, zero); + Value resultRealIsNaN = rewriter.create( + loc, arith::CmpFPredicate::UNO, resultReal, zero); + Value resultImagIsNaN = rewriter.create( + loc, arith::CmpFPredicate::UNO, resultImag, zero); Value resultIsNaN = - rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); + rewriter.create(loc, resultRealIsNaN, resultImagIsNaN); Value resultRealWithSpecialCases = rewriter.create( loc, resultIsNaN, resultRealSpecialCase1, resultReal); Value resultImagWithSpecialCases = rewriter.create( @@ -321,9 +335,9 @@ rewriter.create(loc, elementType, adaptor.complex()); Value expReal = rewriter.create(loc, real); Value cosImag = rewriter.create(loc, imag); - Value resultReal = rewriter.create(loc, 
expReal, cosImag); + Value resultReal = rewriter.create(loc, expReal, cosImag); Value sinImag = rewriter.create(loc, imag); - Value resultImag = rewriter.create(loc, expReal, sinImag); + Value resultImag = rewriter.create(loc, expReal, sinImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -364,9 +378,9 @@ Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); - Value one = - b.create(elementType, b.getFloatAttr(elementType, 1)); - Value realPlusOne = b.create(real, one); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1)); + Value realPlusOne = b.create(real, one); Value newComplex = b.create(type, realPlusOne, imag); rewriter.replaceOpWithNewOp(op, type, newComplex); return success(); @@ -384,126 +398,162 @@ auto elementType = type.getElementType().cast(); Value lhsReal = b.create(elementType, adaptor.lhs()); - Value lhsRealAbs = b.create(lhsReal); + Value lhsRealAbs = b.create(lhsReal); Value lhsImag = b.create(elementType, adaptor.lhs()); - Value lhsImagAbs = b.create(lhsImag); + Value lhsImagAbs = b.create(lhsImag); Value rhsReal = b.create(elementType, adaptor.rhs()); - Value rhsRealAbs = b.create(rhsReal); + Value rhsRealAbs = b.create(rhsReal); Value rhsImag = b.create(elementType, adaptor.rhs()); - Value rhsImagAbs = b.create(rhsImag); + Value rhsImagAbs = b.create(rhsImag); - Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); - Value lhsRealTimesRhsRealAbs = b.create(lhsRealTimesRhsReal); - Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); - Value lhsImagTimesRhsImagAbs = b.create(lhsImagTimesRhsImag); - Value real = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); + Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); + Value lhsRealTimesRhsRealAbs = b.create(lhsRealTimesRhsReal); + Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); + Value lhsImagTimesRhsImagAbs = b.create(lhsImagTimesRhsImag); + Value real = + 
b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); - Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); - Value lhsImagTimesRhsRealAbs = b.create(lhsImagTimesRhsReal); - Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); - Value lhsRealTimesRhsImagAbs = b.create(lhsRealTimesRhsImag); - Value imag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); + Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); + Value lhsImagTimesRhsRealAbs = b.create(lhsImagTimesRhsReal); + Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); + Value lhsRealTimesRhsImagAbs = b.create(lhsRealTimesRhsImag); + Value imag = + b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); // Handle cases where the "naive" calculation results in NaN values. - Value realIsNan = b.create(CmpFPredicate::UNO, real, real); - Value imagIsNan = b.create(CmpFPredicate::UNO, imag, imag); - Value isNan = b.create(realIsNan, imagIsNan); + Value realIsNan = + b.create(arith::CmpFPredicate::UNO, real, real); + Value imagIsNan = + b.create(arith::CmpFPredicate::UNO, imag, imag); + Value isNan = b.create(realIsNan, imagIsNan); - Value inf = b.create( + Value inf = b.create( elementType, b.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); // Case 1. `lhsReal` or `lhsImag` are infinite. 
- Value lhsRealIsInf = b.create(CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagIsInf = b.create(CmpFPredicate::OEQ, lhsImagAbs, inf); - Value lhsIsInf = b.create(lhsRealIsInf, lhsImagIsInf); - Value rhsRealIsNan = b.create(CmpFPredicate::UNO, rhsReal, rhsReal); - Value rhsImagIsNan = b.create(CmpFPredicate::UNO, rhsImag, rhsImag); - Value zero = b.create(elementType, b.getZeroAttr(elementType)); - Value one = - b.create(elementType, b.getFloatAttr(elementType, 1)); + Value lhsRealIsInf = + b.create(arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagIsInf = + b.create(arith::CmpFPredicate::OEQ, lhsImagAbs, inf); + Value lhsIsInf = b.create(lhsRealIsInf, lhsImagIsInf); + Value rhsRealIsNan = + b.create(arith::CmpFPredicate::UNO, rhsReal, rhsReal); + Value rhsImagIsNan = + b.create(arith::CmpFPredicate::UNO, rhsImag, rhsImag); + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1)); Value lhsRealIsInfFloat = b.create(lhsRealIsInf, one, zero); lhsReal = b.create( - lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), lhsReal); + lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), + lhsReal); Value lhsImagIsInfFloat = b.create(lhsImagIsInf, one, zero); lhsImag = b.create( - lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), lhsImag); - Value lhsIsInfAndRhsRealIsNan = b.create(lhsIsInf, rhsRealIsNan); - rhsReal = b.create(lhsIsInfAndRhsRealIsNan, - b.create(zero, rhsReal), rhsReal); - Value lhsIsInfAndRhsImagIsNan = b.create(lhsIsInf, rhsImagIsNan); - rhsImag = b.create(lhsIsInfAndRhsImagIsNan, - b.create(zero, rhsImag), rhsImag); + lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), + lhsImag); + Value lhsIsInfAndRhsRealIsNan = + b.create(lhsIsInf, rhsRealIsNan); + rhsReal = + b.create(lhsIsInfAndRhsRealIsNan, + b.create(zero, rhsReal), rhsReal); + Value lhsIsInfAndRhsImagIsNan = + b.create(lhsIsInf, rhsImagIsNan); + rhsImag = + b.create(lhsIsInfAndRhsImagIsNan, + b.create(zero, rhsImag), 
rhsImag); // Case 2. `rhsReal` or `rhsImag` are infinite. - Value rhsRealIsInf = b.create(CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagIsInf = b.create(CmpFPredicate::OEQ, rhsImagAbs, inf); - Value rhsIsInf = b.create(rhsRealIsInf, rhsImagIsInf); - Value lhsRealIsNan = b.create(CmpFPredicate::UNO, lhsReal, lhsReal); - Value lhsImagIsNan = b.create(CmpFPredicate::UNO, lhsImag, lhsImag); + Value rhsRealIsInf = + b.create(arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagIsInf = + b.create(arith::CmpFPredicate::OEQ, rhsImagAbs, inf); + Value rhsIsInf = b.create(rhsRealIsInf, rhsImagIsInf); + Value lhsRealIsNan = + b.create(arith::CmpFPredicate::UNO, lhsReal, lhsReal); + Value lhsImagIsNan = + b.create(arith::CmpFPredicate::UNO, lhsImag, lhsImag); Value rhsRealIsInfFloat = b.create(rhsRealIsInf, one, zero); rhsReal = b.create( - rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), rhsReal); + rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), + rhsReal); Value rhsImagIsInfFloat = b.create(rhsImagIsInf, one, zero); rhsImag = b.create( - rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), rhsImag); - Value rhsIsInfAndLhsRealIsNan = b.create(rhsIsInf, lhsRealIsNan); - lhsReal = b.create(rhsIsInfAndLhsRealIsNan, - b.create(zero, lhsReal), lhsReal); - Value rhsIsInfAndLhsImagIsNan = b.create(rhsIsInf, lhsImagIsNan); - lhsImag = b.create(rhsIsInfAndLhsImagIsNan, - b.create(zero, lhsImag), lhsImag); - Value recalc = b.create(lhsIsInf, rhsIsInf); + rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), + rhsImag); + Value rhsIsInfAndLhsRealIsNan = + b.create(rhsIsInf, lhsRealIsNan); + lhsReal = + b.create(rhsIsInfAndLhsRealIsNan, + b.create(zero, lhsReal), lhsReal); + Value rhsIsInfAndLhsImagIsNan = + b.create(rhsIsInf, lhsImagIsNan); + lhsImag = + b.create(rhsIsInfAndLhsImagIsNan, + b.create(zero, lhsImag), lhsImag); + Value recalc = b.create(lhsIsInf, rhsIsInf); // Case 3. One of the pairwise products of left hand side with right hand // side is infinite. 
- Value lhsRealTimesRhsRealIsInf = - b.create(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); - Value lhsImagTimesRhsImagIsInf = - b.create(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); - Value isSpecialCase = - b.create(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf); - Value lhsRealTimesRhsImagIsInf = - b.create(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); - isSpecialCase = b.create(isSpecialCase, lhsRealTimesRhsImagIsInf); - Value lhsImagTimesRhsRealIsInf = - b.create(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); - isSpecialCase = b.create(isSpecialCase, lhsImagTimesRhsRealIsInf); + Value lhsRealTimesRhsRealIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf); + Value lhsImagTimesRhsImagIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf); + Value isSpecialCase = b.create(lhsRealTimesRhsRealIsInf, + lhsImagTimesRhsImagIsInf); + Value lhsRealTimesRhsImagIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf); + isSpecialCase = + b.create(isSpecialCase, lhsRealTimesRhsImagIsInf); + Value lhsImagTimesRhsRealIsInf = b.create( + arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf); + isSpecialCase = + b.create(isSpecialCase, lhsImagTimesRhsRealIsInf); Type i1Type = b.getI1Type(); - Value notRecalc = b.create( - recalc, b.create(i1Type, b.getIntegerAttr(i1Type, 1))); - isSpecialCase = b.create(isSpecialCase, notRecalc); + Value notRecalc = b.create( + recalc, + b.create(i1Type, b.getIntegerAttr(i1Type, 1))); + isSpecialCase = b.create(isSpecialCase, notRecalc); Value isSpecialCaseAndLhsRealIsNan = - b.create(isSpecialCase, lhsRealIsNan); - lhsReal = b.create(isSpecialCaseAndLhsRealIsNan, - b.create(zero, lhsReal), lhsReal); + b.create(isSpecialCase, lhsRealIsNan); + lhsReal = + b.create(isSpecialCaseAndLhsRealIsNan, + b.create(zero, lhsReal), lhsReal); Value isSpecialCaseAndLhsImagIsNan = - b.create(isSpecialCase, lhsImagIsNan); - lhsImag = b.create(isSpecialCaseAndLhsImagIsNan, - 
b.create(zero, lhsImag), lhsImag); + b.create(isSpecialCase, lhsImagIsNan); + lhsImag = + b.create(isSpecialCaseAndLhsImagIsNan, + b.create(zero, lhsImag), lhsImag); Value isSpecialCaseAndRhsRealIsNan = - b.create(isSpecialCase, rhsRealIsNan); - rhsReal = b.create(isSpecialCaseAndRhsRealIsNan, - b.create(zero, rhsReal), rhsReal); + b.create(isSpecialCase, rhsRealIsNan); + rhsReal = + b.create(isSpecialCaseAndRhsRealIsNan, + b.create(zero, rhsReal), rhsReal); Value isSpecialCaseAndRhsImagIsNan = - b.create(isSpecialCase, rhsImagIsNan); - rhsImag = b.create(isSpecialCaseAndRhsImagIsNan, - b.create(zero, rhsImag), rhsImag); - recalc = b.create(recalc, isSpecialCase); - recalc = b.create(isNan, recalc); + b.create(isSpecialCase, rhsImagIsNan); + rhsImag = + b.create(isSpecialCaseAndRhsImagIsNan, + b.create(zero, rhsImag), rhsImag); + recalc = b.create(recalc, isSpecialCase); + recalc = b.create(isNan, recalc); // Recalculate real part. - lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); - lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); - Value newReal = b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); - real = b.create(recalc, b.create(inf, newReal), real); + lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); + lhsImagTimesRhsImag = b.create(lhsImag, rhsImag); + Value newReal = + b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag); + real = + b.create(recalc, b.create(inf, newReal), real); // Recalculate imag part. 
- lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); - lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); - Value newImag = b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); - imag = b.create(recalc, b.create(inf, newImag), imag); + lhsImagTimesRhsReal = b.create(lhsImag, rhsReal); + lhsRealTimesRhsImag = b.create(lhsReal, rhsImag); + Value newImag = + b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag); + imag = + b.create(recalc, b.create(inf, newImag), imag); rewriter.replaceOpWithNewOp(op, type, real, imag); return success(); @@ -524,8 +574,8 @@ rewriter.create(loc, elementType, adaptor.complex()); Value imag = rewriter.create(loc, elementType, adaptor.complex()); - Value negReal = rewriter.create(loc, real); - Value negImag = rewriter.create(loc, imag); + Value negReal = rewriter.create(loc, real); + Value negImag = rewriter.create(loc, imag); rewriter.replaceOpWithNewOp(op, type, negReal, negImag); return success(); } @@ -543,13 +593,16 @@ Value real = b.create(elementType, adaptor.complex()); Value imag = b.create(elementType, adaptor.complex()); - Value zero = b.create(elementType, b.getZeroAttr(elementType)); - Value realIsZero = b.create(CmpFPredicate::OEQ, real, zero); - Value imagIsZero = b.create(CmpFPredicate::OEQ, imag, zero); - Value isZero = b.create(realIsZero, imagIsZero); + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + Value realIsZero = + b.create(arith::CmpFPredicate::OEQ, real, zero); + Value imagIsZero = + b.create(arith::CmpFPredicate::OEQ, imag, zero); + Value isZero = b.create(realIsZero, imagIsZero); auto abs = b.create(elementType, adaptor.complex()); - Value realSign = b.create(real, abs); - Value imagSign = b.create(imag, abs); + Value realSign = b.create(real, abs); + Value imagSign = b.create(imag, abs); Value sign = b.create(type, realSign, imagSign); rewriter.replaceOpWithNewOp(op, isZero, adaptor.complex(), sign); return success(); @@ -562,10 +615,10 @@ // clang-format off patterns.add< AbsOpConversion, - 
ComparisonOpConversion, - ComparisonOpConversion, - BinaryComplexOpConversion, - BinaryComplexOpConversion, + ComparisonOpConversion, + ComparisonOpConversion, + BinaryComplexOpConversion, + BinaryComplexOpConversion, DivOpConversion, ExpOpConversion, LogOpConversion, @@ -590,7 +643,8 @@ populateComplexToStandardConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt --- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt @@ -29,6 +29,7 @@ ${NVPTX_LIBS} LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRAsyncToLLVM MLIRGPUTransforms MLIRIR diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -349,6 +350,7 @@ target.addIllegalDialect(); + populateArithmeticToLLVMConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt @@ -11,11 +11,11 @@ MLIRGPUToNVVMIncGen LINK_LIBS PUBLIC + 
MLIRArithmeticToLLVM MLIRGPUOps MLIRGPUToGPURuntimeTransforms MLIRLLVMCommonConversion MLIRLLVMIR - MLIRMemRef MLIRMemRefToLLVM MLIRNVVMIR MLIRPass diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -13,11 +13,13 @@ #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" @@ -169,6 +171,7 @@ populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); + populateArithmeticToLLVMConversionPatterns(converter, llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns); populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns); @@ -217,14 +220,14 @@ Identifier::get(NVVM::NVVMDialect::getKernelFuncAttrName(), &converter.getContext())); - patterns.add>(converter, "__nv_fabsf", - "__nv_fabs"); + patterns.add>(converter, "__nv_fabsf", + "__nv_fabs"); patterns.add>(converter, "__nv_atanf", "__nv_atan"); patterns.add>(converter, "__nv_atan2f", "__nv_atan2"); - patterns.add>(converter, "__nv_ceilf", - "__nv_ceil"); + patterns.add>(converter, "__nv_ceilf", + "__nv_ceil"); patterns.add>(converter, "__nv_cosf", "__nv_cos"); patterns.add>(converter, "__nv_expf", @@ -233,8 +236,8 @@ "__nv_exp2"); patterns.add>(converter, "__nv_expm1f", "__nv_expm1"); - 
patterns.add>(converter, "__nv_floorf", - "__nv_floor"); + patterns.add>(converter, "__nv_floorf", + "__nv_floor"); patterns.add>(converter, "__nv_logf", "__nv_log"); patterns.add>(converter, "__nv_log1pf", diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRGPUToROCDLIncGen LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRGPUOps MLIRGPUToGPURuntimeTransforms MLIRLLVMCommonConversion diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -72,6 +73,7 @@ populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); + populateArithmeticToLLVMConversionPatterns(converter, llvmPatterns); populateVectorToLLVMConversionPatterns(converter, llvmPatterns); populateVectorToROCDLConversionPatterns(converter, llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns); @@ -116,14 +118,14 @@ converter, /*allocaAddrSpace=*/5, Identifier::get(ROCDL::ROCDLDialect::getKernelFuncAttrName(), &converter.getContext())); - patterns.add>(converter, "__ocml_fabs_f32", - "__ocml_fabs_f64"); + patterns.add>(converter, "__ocml_fabs_f32", + "__ocml_fabs_f64"); patterns.add>(converter, "__ocml_atan_f32", "__ocml_atan_f64"); patterns.add>( converter, "__ocml_atan2_f32", "__ocml_atan2_f64"); - patterns.add>(converter, "__ocml_ceil_f32", - "__ocml_ceil_f64"); + 
patterns.add>(converter, "__ocml_ceil_f32", + "__ocml_ceil_f64"); patterns.add>(converter, "__ocml_cos_f32", "__ocml_cos_f64"); patterns.add>(converter, "__ocml_exp_f32", @@ -132,8 +134,8 @@ "__ocml_exp2_f64"); patterns.add>( converter, "__ocml_expm1_f32", "__ocml_expm1_f64"); - patterns.add>(converter, "__ocml_floor_f32", - "__ocml_floor_f64"); + patterns.add>( + converter, "__ocml_floor_f32", "__ocml_floor_f64"); patterns.add>(converter, "__ocml_log_f32", "__ocml_log_f64"); patterns.add>( diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt @@ -6,13 +6,13 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRGPUOps MLIRIR MLIRPass MLIRSCFToSPIRV MLIRSPIRV MLIRSPIRVConversion - MLIRStandard MLIRStandardToSPIRV MLIRSupport MLIRTransforms diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" @@ -63,6 +64,7 @@ // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. 
+ populateArithmeticToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -184,7 +184,8 @@ void ConvertLinalgToStandardPass::runOnOperation() { auto module = getOperation(); ConversionTarget target(getContext()); - target.addLegalDialect(); target.addLegalOp(); RewritePatternSet patterns(&getContext()); diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -18,9 +18,16 @@ using namespace mlir; namespace { +using AbsOpLowering = VectorConvertToLLVMPattern; +using CeilOpLowering = VectorConvertToLLVMPattern; +using CopySignOpLowering = + VectorConvertToLLVMPattern; using CosOpLowering = VectorConvertToLLVMPattern; using ExpOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = VectorConvertToLLVMPattern; +using FloorOpLowering = + VectorConvertToLLVMPattern; +using FmaOpLowering = VectorConvertToLLVMPattern; using Log10OpLowering = VectorConvertToLLVMPattern; using Log2OpLowering = VectorConvertToLLVMPattern; @@ -209,10 +216,15 @@ RewritePatternSet &patterns) { // clang-format off patterns.add< + AbsOpLowering, + CeilOpLowering, + CopySignOpLowering, CosOpLowering, ExpOpLowering, Exp2OpLowering, ExpM1OpLowering, + FloorOpLowering, + FmaOpLowering, Log10OpLowering, Log1pOpLowering, Log2OpLowering, diff --git a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt --- a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt +++ b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + 
MLIRArithmetic MLIRMath MLIRStandardOpsTransforms ) diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -61,7 +62,7 @@ if (shape.size() != 1) return failure(); - Value result = rewriter.create( + Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, FloatAttr::get(vecType.getElementType(), 0.0))); for (auto i = 0; i < shape.front(); ++i) { @@ -135,8 +136,8 @@ populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalDialect(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/SPIRVCommon/Pattern.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -29,31 +30,6 @@ // normal RewritePattern. namespace { - -/// Converts unary and binary standard operations to SPIR-V operations. 
-template -class UnaryAndBinaryOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() <= 2); - auto dstType = this->getTypeConverter()->convertType(operation.getType()); - if (!dstType) - return failure(); - if (SPIRVOp::template hasTrait() && - dstType != operation.getType()) { - return operation.emitError( - "bitwidth emulation is not implemented yet on unsigned op"); - } - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - return success(); - } -}; - /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to @@ -76,7 +52,6 @@ return success(); } }; - } // namespace //===----------------------------------------------------------------------===// @@ -86,15 +61,19 @@ namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern>( + patterns.add< + Log1pOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern>( typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt --- a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt +++ b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt @@ 
-8,8 +8,9 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIROpenACC - MLIRTransforms MLIRSCF + MLIRTransforms ) diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp --- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp +++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp @@ -8,6 +8,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -33,7 +34,7 @@ return success(); // Condition is not a constant. - if (!op.ifCond().template getDefiningOp()) { + if (!op.ifCond().template getDefiningOp()) { auto ifOp = rewriter.create(op.getLoc(), TypeRange(), op.ifCond(), false); rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt b/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt @@ -12,6 +12,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRIR MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -65,6 +66,7 @@ // Convert to OpenMP operations with LLVM IR dialect RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); + 
populateArithmeticToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -23,6 +23,10 @@ class OpenACCDialect; } // end namespace acc +namespace arith { +class ArithmeticDialect; +} // end namespace arith + namespace complex { class ComplexDialect; } // end namespace complex diff --git a/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt b/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToGPU/CMakeLists.txt @@ -11,6 +11,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRAffineToStandard + MLIRArithmetic MLIRComplex MLIRGPUTransforms MLIRIR diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/ParallelLoopMapper.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -83,7 +84,8 @@ // Get a Value that corresponds to the loop step. If the step is an attribute, // materialize a corresponding constant using builder. static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { - return builder.create(forOp.getLoc(), forOp.getStep()); + return builder.create(forOp.getLoc(), + forOp.getStep()); } // Get a Value for the loop lower bound. If the value requires computation, @@ -169,8 +171,8 @@ // Return true if the value is obviously a constant "one". 
static bool isConstantOne(Value value) { - if (auto def = value.getDefiningOp()) - return def.getValue() == 1; + if (auto def = value.getDefiningOp()) + return def.value() == 1; return false; } @@ -194,11 +196,11 @@ return llvm::None; } - Value range = - builder.create(currentLoop.getLoc(), upperBound, lowerBound); + Value range = builder.create(currentLoop.getLoc(), + upperBound, lowerBound); Value step = getOrCreateStep(currentLoop, builder); if (!isConstantOne(step)) - range = builder.create(currentLoop.getLoc(), range, step); + range = builder.create(currentLoop.getLoc(), range, step); dims.push_back(range); lbs.push_back(lowerBound); @@ -222,9 +224,10 @@ OpBuilder builder(rootForOp.getOperation()); // Prepare the grid and block sizes for the launch operation. If there is // no loop mapped to a specific dimension, use constant "1" as its size. - Value constOne = (numBlockDims < 3 || numThreadDims < 3) - ? builder.create(rootForOp.getLoc(), 1) - : nullptr; + Value constOne = + (numBlockDims < 3 || numThreadDims < 3) + ? builder.create(rootForOp.getLoc(), 1) + : nullptr; Value gridSizeX = numBlockDims > 0 ? dims[0] : constOne; Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne; Value gridSizeZ = numBlockDims > 2 ? dims[2] : constOne; @@ -265,10 +268,10 @@ : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); Value step = steps[en.index()]; if (!isConstantOne(step)) - id = builder.create(rootForOp.getLoc(), step, id); + id = builder.create(rootForOp.getLoc(), step, id); Value ivReplacement = - builder.create(rootForOp.getLoc(), *lbArgumentIt, id); + builder.create(rootForOp.getLoc(), *lbArgumentIt, id); en.value().replaceAllUsesWith(ivReplacement); std::advance(lbArgumentIt, 1); std::advance(stepArgumentIt, 1); @@ -314,33 +317,33 @@ /// `upperBound`. 
static Value deriveStaticUpperBound(Value upperBound, PatternRewriter &rewriter) { - if (auto op = upperBound.getDefiningOp()) { + if (auto op = upperBound.getDefiningOp()) { return op; } if (auto minOp = upperBound.getDefiningOp()) { for (const AffineExpr &result : minOp.map().getResults()) { if (auto constExpr = result.dyn_cast()) { - return rewriter.create(minOp.getLoc(), - constExpr.getValue()); + return rewriter.create(minOp.getLoc(), + constExpr.getValue()); } } } - if (auto multiplyOp = upperBound.getDefiningOp()) { - if (auto lhs = dyn_cast_or_null( + if (auto multiplyOp = upperBound.getDefiningOp()) { + if (auto lhs = dyn_cast_or_null( deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter) .getDefiningOp())) - if (auto rhs = dyn_cast_or_null( + if (auto rhs = dyn_cast_or_null( deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter) .getDefiningOp())) { // Assumptions about the upper bound of minimum computations no longer // work if multiplied by a negative value, so abort in this case. - if (lhs.getValue() < 0 || rhs.getValue() < 0) + if (lhs.value() < 0 || rhs.value() < 0) return {}; - return rewriter.create( - multiplyOp.getLoc(), lhs.getValue() * rhs.getValue()); + return rewriter.create( + multiplyOp.getLoc(), lhs.value() * rhs.value()); } } @@ -416,8 +419,9 @@ launchIndependent](Value val) -> Value { if (launchIndependent(val)) return val; - if (ConstantOp constOp = val.getDefiningOp()) - return rewriter.create(constOp.getLoc(), constOp.getValue()); + if (auto constOp = val.getDefiningOp()) + return rewriter.create(constOp.getLoc(), + constOp.value()); return {}; }; @@ -460,17 +464,17 @@ // conditional. If the lower-bound is constant or defined before the // launch, we can use it in the launch bounds. Otherwise fail. if (!launchIndependent(lowerBound) && - !isa_and_nonnull(lowerBound.getDefiningOp())) + !isa_and_nonnull(lowerBound.getDefiningOp())) return failure(); // The step must also be constant or defined outside of the loop nest. 
if (!launchIndependent(step) && - !isa_and_nonnull(step.getDefiningOp())) + !isa_and_nonnull(step.getDefiningOp())) return failure(); // If the upper-bound is constant or defined before the launch, we can // use it in the launch bounds directly. Otherwise try derive a bound. bool boundIsPrecise = launchIndependent(upperBound) || - isa_and_nonnull(upperBound.getDefiningOp()); + isa_and_nonnull(upperBound.getDefiningOp()); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(launchOp); @@ -510,8 +514,8 @@ if (!boundIsPrecise) { // We are using an approximation, create a surrounding conditional. Value originalBound = std::get<3>(config); - CmpIOp pred = rewriter.create( - loc, CmpIPredicate::slt, newIndex, + arith::CmpIOp pred = rewriter.create( + loc, arith::CmpIPredicate::slt, newIndex, cloningMap.lookupOrDefault(originalBound)); scf::IfOp ifOp = rewriter.create(loc, pred, false); rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); @@ -595,7 +599,8 @@ // Create a launch operation. We start with bound one for all grid/block // sizes. Those will be refined later as we discover them from mappings. 
Location loc = parallelOp.getLoc(); - Value constantOne = rewriter.create(parallelOp.getLoc(), 1); + Value constantOne = + rewriter.create(parallelOp.getLoc(), 1); gpu::LaunchOp launchOp = rewriter.create( parallelOp.getLoc(), constantOne, constantOne, constantOne, constantOne, constantOne, constantOne); diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -10,6 +10,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/SCFToGPU/SCFToGPU.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/SCF/SCF.h" diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRAnalysis + MLIRArithmetic MLIRLLVMIR MLIROpenMP MLIRSCF diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "../PassDetail.h" #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" @@ -248,27 +249,27 @@ // Match simple binary reductions that can be expressed with atomicrmw. 
Type type = reduce.operand().getType(); Block &reduction = reduce.getRegion().front(); - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getFloatAttr(type, 0.0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); } - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl( builder, symbolTable, reduce, builder.getIntegerAttr( @@ -279,25 +280,25 @@ // Match simple binary reductions that cannot be expressed with atomicrmw. // TODO: add atomic region using cmpxchg (which needs atomic load to be // available as an op). - if (matchSimpleReduction(reduction)) { + if (matchSimpleReduction(reduction)) { return createDecl(builder, symbolTable, reduce, builder.getFloatAttr(type, 1.0)); } // Match select-based min/max reductions. 
bool isMin; - if (matchSelectReduction( - reduction, {CmpFPredicate::OLT, CmpFPredicate::OLE}, - {CmpFPredicate::OGT, CmpFPredicate::OGE}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, + {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || matchSelectReduction( reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { return createDecl(builder, symbolTable, reduce, minMaxValueForFloat(type, !isMin)); } - if (matchSelectReduction( - reduction, {CmpIPredicate::slt, CmpIPredicate::sle}, - {CmpIPredicate::sgt, CmpIPredicate::sge}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, + {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || matchSelectReduction( reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { @@ -307,9 +308,9 @@ isMin ? 
LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, decl, reduce); } - if (matchSelectReduction( - reduction, {CmpIPredicate::ult, CmpIPredicate::ule}, - {CmpIPredicate::ugt, CmpIPredicate::uge}, isMin) || + if (matchSelectReduction( + reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, + {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || matchSelectReduction( reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt @@ -9,6 +9,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRMemRefToSPIRV MLIRSPIRV MLIRSPIRVConversion diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" @@ -43,6 +44,7 @@ // TODO: Change SPIR-V conversion to be progressive and remove the following // patterns. 
+ populateArithmeticToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt b/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRSCF MLIRTransforms ) diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -314,7 +315,7 @@ Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.step(); - auto stepped = rewriter.create(loc, iv, step).getResult(); + auto stepped = rewriter.create(loc, iv, step).getResult(); if (!stepped) return failure(); @@ -341,8 +342,8 @@ // With the body block done, we can fill in the condition block. 
rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = - rewriter.create(loc, CmpIPredicate::slt, iv, upperBound); + auto comparison = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, upperBound); rewriter.create(loc, comparison, firstBodyBlock, ArrayRef(), endBlock, ArrayRef()); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt @@ -11,6 +11,7 @@ intrinsics_gen LINK_LIBS PUBLIC + MLIRArithmeticToLLVM MLIRGPUOps MLIRSPIRV MLIRSPIRVUtils diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -287,6 +288,7 @@ auto *context = module.getContext(); RewritePatternSet patterns(context); LLVMTypeConverter typeConverter(context, options); + populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); populateMemRefToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); patterns.add(typeConverter); diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt @@ -17,6 +17,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRShape MLIRTensor diff --git 
a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "../PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -75,13 +76,13 @@ // number of extent tensors and shifted offsets into them. Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, ValueRange rankDiffs, Value outputDimension) { - Value one = lb.create(1); + Value one = lb.create(1); Value broadcastedDim = one; for (auto tup : llvm::zip(extentTensors, rankDiffs)) { Value shape = std::get<0>(tup); Value rankDiff = std::get<1>(tup); - Value outOfBounds = - lb.create(CmpIPredicate::ult, outputDimension, rankDiff); + Value outOfBounds = lb.create(arith::CmpIPredicate::ult, + outputDimension, rankDiff); Type indexTy = lb.getIndexType(); broadcastedDim = lb.create( @@ -97,13 +98,14 @@ // - otherwise, take the extent as-is. // Note that this logic remains correct in the presence // of dimensions of zero extent. 
- Value lesserRankOperandDimension = - b.create(loc, indexTy, outputDimension, rankDiff); + Value lesserRankOperandDimension = b.create( + loc, indexTy, outputDimension, rankDiff); Value lesserRankOperandExtent = b.create( loc, shape, ValueRange{lesserRankOperandDimension}); - Value dimIsOne = b.create(loc, CmpIPredicate::eq, - lesserRankOperandExtent, one); + Value dimIsOne = + b.create(loc, arith::CmpIPredicate::eq, + lesserRankOperandExtent, one); Value dim = b.create(loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); b.create(loc, dim); @@ -125,7 +127,7 @@ auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); + Value zero = lb.create(0); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -139,13 +141,14 @@ // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - Value rankIsGreater = lb.create(CmpIPredicate::ugt, v, maxRank); + Value rankIsGreater = + lb.create(arith::CmpIPredicate::ugt, v, maxRank); maxRank = lb.create(rankIsGreater, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. 
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return lb.create(indexTy, maxRank, v); })); Value replacement = lb.create( @@ -186,7 +189,7 @@ SmallVector extentOperands; for (auto extent : op.shape()) { extentOperands.push_back( - rewriter.create(loc, extent.getLimitedValue())); + rewriter.create(loc, extent.getLimitedValue())); } Type indexTy = rewriter.getIndexType(); Value tensor = @@ -210,7 +213,8 @@ LogicalResult ConstSizeOpConversion::matchAndRewrite( ConstSizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(op, op.value().getSExtValue()); + rewriter.replaceOpWithNewOp( + op, op.value().getSExtValue()); return success(); } @@ -236,8 +240,8 @@ auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create(0); - Value one = lb.create(1); + Value zero = lb.create(0); + Value one = lb.create(1); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -251,18 +255,19 @@ // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - Value rankIsGreater = lb.create(CmpIPredicate::ugt, v, maxRank); + Value rankIsGreater = + lb.create(arith::CmpIPredicate::ugt, v, maxRank); maxRank = lb.create(rankIsGreater, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. 
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create(indexTy, maxRank, v); + return lb.create(indexTy, maxRank, v); })); Type i1Ty = rewriter.getI1Type(); Value trueVal = - rewriter.create(loc, i1Ty, rewriter.getBoolAttr(true)); + rewriter.create(loc, i1Ty, rewriter.getBoolAttr(true)); auto reduceResult = lb.create( loc, zero, maxRank, one, ValueRange{trueVal}, @@ -277,8 +282,8 @@ for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; - Value outOfBounds = - b.create(loc, CmpIPredicate::ult, iv, rankDiff); + Value outOfBounds = b.create( + loc, arith::CmpIPredicate::ult, iv, rankDiff); broadcastable = b.create( loc, TypeRange{i1Ty}, outOfBounds, @@ -290,18 +295,19 @@ // Every value needs to be either 1, or the same non-1 // value to be broadcastable in this dim. Value operandDimension = - b.create(loc, indexTy, iv, rankDiff); + b.create(loc, indexTy, iv, rankDiff); Value dimensionExtent = b.create( loc, shape, ValueRange{operandDimension}); - Value equalOne = b.create(loc, CmpIPredicate::eq, - dimensionExtent, one); - Value equalBroadcasted = - b.create(loc, CmpIPredicate::eq, - dimensionExtent, broadcastedDim); - Value result = b.create( + Value equalOne = b.create( + loc, arith::CmpIPredicate::eq, dimensionExtent, one); + Value equalBroadcasted = b.create( + loc, arith::CmpIPredicate::eq, dimensionExtent, + broadcastedDim); + Value result = b.create( loc, broadcastable, - b.create(loc, equalOne, equalBroadcasted)); + b.create(loc, equalOne, + equalBroadcasted)); b.create(loc, result); }) .getResult(0); @@ -389,8 +395,8 @@ auto loc = op.getLoc(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = rewriter.create(loc, indexTy, adaptor.shape(), zero); @@ -433,20 +439,20 @@ /// %c0 = constant 0 : index /// %0 = dim 
%arg0, %c0 : tensor /// %1 = dim %arg1, %c0 : tensor -/// %2 = cmpi "eq", %0, %1 : index +/// %2 = arith.cmpi "eq", %0, %1 : index /// %result = scf.if %2 -> (i1) { -/// %c1 = constant 1 : index -/// %true = constant true +/// %c1 = arith.constant 1 : index +/// %true = arith.constant true /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { /// %5 = tensor.extract %arg0[%arg2] : tensor /// %6 = tensor.extract %arg1[%arg2] : tensor -/// %7 = cmpi "eq", %5, %6 : index -/// %8 = and %arg3, %7 : i1 +/// %7 = arith.cmpi "eq", %5, %6 : index +/// %8 = arith.andi %arg3, %7 : i1 /// scf.yield %8 : i1 /// } /// scf.yield %4 : i1 /// } else { -/// %false = constant false +/// %false = arith.constant false /// scf.yield %false : i1 /// } /// @@ -468,14 +474,14 @@ Type i1Ty = rewriter.getI1Type(); if (op.shapes().size() <= 1) { - rewriter.replaceOpWithNewOp(op, i1Ty, - rewriter.getBoolAttr(true)); + rewriter.replaceOpWithNewOp(op, i1Ty, + rewriter.getBoolAttr(true)); return success(); } auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); - Value zero = rewriter.create(loc, 0); + Value zero = rewriter.create(loc, 0); Value firstShape = adaptor.shapes().front(); Value firstRank = rewriter.create(loc, indexTy, firstShape, zero); @@ -483,13 +489,14 @@ // Generate a linear sequence of compares, all with firstShape as lhs.
for (Value shape : adaptor.shapes().drop_front(1)) { Value rank = rewriter.create(loc, indexTy, shape, zero); - Value eqRank = - rewriter.create(loc, CmpIPredicate::eq, firstRank, rank); + Value eqRank = rewriter.create(loc, arith::CmpIPredicate::eq, + firstRank, rank); auto same = rewriter.create( loc, i1Ty, eqRank, [&](OpBuilder &b, Location loc) { - Value one = b.create(loc, 1); - Value init = b.create(loc, i1Ty, b.getBoolAttr(true)); + Value one = b.create(loc, 1); + Value init = + b.create(loc, i1Ty, b.getBoolAttr(true)); auto loop = b.create( loc, zero, firstRank, one, ValueRange{init}, [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { @@ -497,19 +504,21 @@ Value lhsExtent = b.create(loc, firstShape, iv); Value rhsExtent = b.create(loc, shape, iv); - Value eqExtent = b.create(loc, CmpIPredicate::eq, - lhsExtent, rhsExtent); - Value conjNext = b.create(loc, conj, eqExtent); + Value eqExtent = b.create( + loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); + Value conjNext = b.create(loc, conj, eqExtent); b.create(loc, ValueRange({conjNext})); }); b.create(loc, loop.getResults()); }, [&](OpBuilder &b, Location loc) { - Value result = b.create(loc, i1Ty, b.getBoolAttr(false)); + Value result = + b.create(loc, i1Ty, b.getBoolAttr(false)); b.create(loc, result); }); result = !result ? same.getResult(0) - : rewriter.create(loc, result, same.getResult(0)); + : rewriter.create(loc, result, + same.getResult(0)); } rewriter.replaceOp(op, result); return success(); @@ -549,8 +558,8 @@ Value extent = rewriter.create(loc, tensor, i); extentValues.push_back(extent); } else { - Value extent = - rewriter.create(loc, rankedTensorTy.getDimSize(i)); + Value extent = rewriter.create( + loc, rankedTensorTy.getDimSize(i)); extentValues.push_back(extent); } } @@ -598,20 +607,20 @@ return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value zero = b.create(0); + Value zero = b.create(0); Value rank = b.create(adaptor.operand(), zero); // index < 0 ? 
index + rank : index Value originalIndex = adaptor.index(); - Value add = b.create(originalIndex, rank); + Value add = b.create(originalIndex, rank); Value indexIsNegative = - b.create(CmpIPredicate::slt, originalIndex, zero); + b.create(arith::CmpIPredicate::slt, originalIndex, zero); Value index = b.create(indexIsNegative, add, originalIndex); - Value one = b.create(1); + Value one = b.create(1); Value head = b.create(adaptor.operand(), zero, index, one); - Value tailSize = b.create(rank, index); + Value tailSize = b.create(rank, index); Value tail = b.create(adaptor.operand(), index, tailSize, one); rewriter.replaceOp(op, {head, tail}); @@ -655,8 +664,8 @@ // Setup target legality. MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target - .addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); // Setup conversion patterns. @@ -675,8 +684,8 @@ populateWithGenerated(patterns); patterns.add< AnyOpConversion, - BinaryOpConversion, - BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, BroadcastOpConverter, ConstShapeOpConverter, ConstSizeOpConversion, diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRAnalysis + MLIRArithmeticToLLVM MLIRDataLayoutInterfaces MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -13,6 +13,7 @@ #include "../PassDetail.h" #include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include 
"mlir/Conversion/LLVMCommon/VectorPattern.h" @@ -390,54 +391,7 @@ }; // Straightforward lowerings. -using AbsFOpLowering = VectorConvertToLLVMPattern; -using AddFOpLowering = VectorConvertToLLVMPattern; -using AddIOpLowering = VectorConvertToLLVMPattern; -using AndOpLowering = VectorConvertToLLVMPattern; -using BitcastOpLowering = - VectorConvertToLLVMPattern; -using CeilFOpLowering = VectorConvertToLLVMPattern; -using CopySignOpLowering = - VectorConvertToLLVMPattern; -using DivFOpLowering = VectorConvertToLLVMPattern; -using FPExtOpLowering = VectorConvertToLLVMPattern; -using FPToSIOpLowering = VectorConvertToLLVMPattern; -using FPToUIOpLowering = VectorConvertToLLVMPattern; -using FPTruncOpLowering = - VectorConvertToLLVMPattern; -using FloorFOpLowering = VectorConvertToLLVMPattern; -using FmaFOpLowering = VectorConvertToLLVMPattern; -using MulFOpLowering = VectorConvertToLLVMPattern; -using MulIOpLowering = VectorConvertToLLVMPattern; -using NegFOpLowering = VectorConvertToLLVMPattern; -using OrOpLowering = VectorConvertToLLVMPattern; -using RemFOpLowering = VectorConvertToLLVMPattern; -using SIToFPOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = VectorConvertToLLVMPattern; -using SignExtendIOpLowering = - VectorConvertToLLVMPattern; -using ShiftLeftOpLowering = - VectorConvertToLLVMPattern; -using SignedDivIOpLowering = - VectorConvertToLLVMPattern; -using SignedRemIOpLowering = - VectorConvertToLLVMPattern; -using SignedShiftRightOpLowering = - VectorConvertToLLVMPattern; -using SubFOpLowering = VectorConvertToLLVMPattern; -using SubIOpLowering = VectorConvertToLLVMPattern; -using TruncateIOpLowering = - VectorConvertToLLVMPattern; -using UIToFPOpLowering = VectorConvertToLLVMPattern; -using UnsignedDivIOpLowering = - VectorConvertToLLVMPattern; -using UnsignedRemIOpLowering = - VectorConvertToLLVMPattern; -using UnsignedShiftRightOpLowering = - VectorConvertToLLVMPattern; -using XOrOpLowering = VectorConvertToLLVMPattern; -using 
ZeroExtendIOpLowering = - VectorConvertToLLVMPattern; /// Lower `std.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is @@ -651,118 +605,6 @@ } }; -// The lowering of index_cast becomes an integer conversion since index becomes -// an integer. If the bit width of the source and target integer types is the -// same, just erase the cast. If the target type is wider, sign-extend the -// value, otherwise truncate it. -struct IndexCastOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto targetType = - typeConverter->convertType(indexCastOp.getResult().getType()); - auto targetElementType = - typeConverter - ->convertType(getElementTypeOrSelf(indexCastOp.getResult())) - .cast(); - auto sourceElementType = - getElementTypeOrSelf(adaptor.in()).cast(); - unsigned targetBits = targetElementType.getWidth(); - unsigned sourceBits = sourceElementType.getWidth(); - - if (targetBits == sourceBits) - rewriter.replaceOp(indexCastOp, adaptor.in()); - else if (targetBits < sourceBits) - rewriter.replaceOpWithNewOp(indexCastOp, targetType, - adaptor.in()); - else - rewriter.replaceOpWithNewOp(indexCastOp, targetType, - adaptor.in()); - return success(); - } -}; - -// Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two -// enums share the numerical values so just cast. 
-template -static LLVMPredType convertCmpPredicate(StdPredType pred) { - return static_cast(pred); -} - -struct CmpIOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto operandType = adaptor.lhs().getType(); - auto resultType = cmpiOp.getResult().getType(); - - // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { - rewriter.replaceOpWithNewOp( - cmpiOp, typeConverter->convertType(resultType), - convertCmpPredicate(cmpiOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type"); - - return LLVM::detail::handleMultidimensionalVectors( - cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - CmpIOpAdaptor adaptor(operands); - return rewriter.create( - cmpiOp.getLoc(), llvm1DVectorTy, - convertCmpPredicate(cmpiOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - }, - rewriter); - - return success(); - } -}; - -struct CmpFOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto operandType = adaptor.lhs().getType(); - auto resultType = cmpfOp.getResult().getType(); - - // Handle the scalar and 1D vector cases. 
- if (!operandType.isa()) { - rewriter.replaceOpWithNewOp( - cmpfOp, typeConverter->convertType(resultType), - convertCmpPredicate(cmpfOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - return success(); - } - - auto vectorType = resultType.dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type"); - - return LLVM::detail::handleMultidimensionalVectors( - cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) { - CmpFOpAdaptor adaptor(operands); - return rewriter.create( - cmpfOp.getLoc(), llvm1DVectorTy, - convertCmpPredicate(cmpfOp.getPredicate()), - adaptor.lhs(), adaptor.rhs()); - }, - rewriter); - } -}; - // Base class for LLVM IR lowering terminator operations with successors. template struct OneToOneLLVMTerminatorLowering @@ -1131,57 +973,20 @@ populateStdToLLVMFuncOpConversionPattern(converter, patterns); // clang-format off patterns.add< - AbsFOpLowering, - AddFOpLowering, - AddIOpLowering, - AndOpLowering, AssertOpLowering, AtomicRMWOpLowering, - BitcastOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, - CeilFOpLowering, - CmpFOpLowering, - CmpIOpLowering, CondBranchOpLowering, - CopySignOpLowering, ConstantOpLowering, - DivFOpLowering, - FloorFOpLowering, - FmaFOpLowering, GenericAtomicRMWOpLowering, - FPExtOpLowering, - FPToSIOpLowering, - FPToUIOpLowering, - FPTruncOpLowering, - IndexCastOpLowering, - MulFOpLowering, - MulIOpLowering, - NegFOpLowering, - OrOpLowering, - RemFOpLowering, RankOpLowering, ReturnOpLowering, - SIToFPOpLowering, SelectOpLowering, - ShiftLeftOpLowering, - SignExtendIOpLowering, - SignedDivIOpLowering, - SignedRemIOpLowering, - SignedShiftRightOpLowering, SplatOpLowering, SplatNdOpLowering, - SubFOpLowering, - SubIOpLowering, - SwitchOpLowering, - TruncateIOpLowering, - UIToFPOpLowering, - UnsignedDivIOpLowering, - UnsignedRemIOpLowering, - UnsignedShiftRightOpLowering, - XOrOpLowering, - 
ZeroExtendIOpLowering>(converter); + SwitchOpLowering>(converter); // clang-format on } @@ -1231,6 +1036,7 @@ RewritePatternSet patterns(&getContext()); populateStdToLLVMConversionPatterns(typeConverter, patterns); + populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -10,8 +10,9 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmeticToSPIRV MLIRIR - MLIRMath + MLIRMathToSPIRV MLIRMemRef MLIRPass MLIRSPIRV diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/SPIRVCommon/Pattern.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -98,35 +99,6 @@ return builder.getF32FloatAttr(dstVal.convertToFloat()); } -/// Returns signed remainder for `lhs` and `rhs` and lets the result follow -/// the sign of `signOperand`. -/// -/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment -/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative -/// the result is undefined." So we cannot directly use spv.SRem/spv.SMod -/// if either operand can be negative. Emulate it via spv.UMod. 
-static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, - Value signOperand, OpBuilder &builder) { - assert(lhs.getType() == rhs.getType()); - assert(lhs == signOperand || rhs == signOperand); - - Type type = lhs.getType(); - - // Calculate the remainder with spv.UMod. - Value lhsAbs = builder.create(loc, type, lhs); - Value rhsAbs = builder.create(loc, type, rhs); - Value abs = builder.create(loc, lhsAbs, rhsAbs); - - // Fix the sign. - Value isPositive; - if (lhs == signOperand) - isPositive = builder.create(loc, lhs, lhsAbs); - else - isPositive = builder.create(loc, rhs, rhsAbs); - Value absNegate = builder.create(loc, type, abs); - return builder.create(loc, type, isPositive, abs, absNegate); -} - //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -137,71 +109,6 @@ namespace { -/// Converts unary and binary standard operations to SPIR-V operations. -template -class UnaryAndBinaryOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() <= 2); - auto dstType = this->getTypeConverter()->convertType(operation.getType()); - if (!dstType) - return failure(); - if (SPIRVOp::template hasTrait() && - dstType != operation.getType()) { - return operation.emitError( - "bitwidth emulation is not implemented yet on unsigned op"); - } - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - return success(); - } -}; - -/// Converts std.remi_signed to SPIR-V ops. -/// -/// This cannot be merged into the template unary/binary pattern due to -/// Vulkan restrictions over spv.SRem and spv.SMod. 
-class SignedRemIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(SignedRemIOp remOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts bitwise standard operations to SPIR-V operations. This is a special -/// pattern other than the BinaryOpPatternPattern because if the operands are -/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For -/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. -template -class BitwiseOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() == 2); - auto dstType = - this->getTypeConverter()->convertType(operation.getResult().getType()); - if (!dstType) - return failure(); - if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { - rewriter.template replaceOpWithNewOp( - operation, dstType, adaptor.getOperands()); - } else { - rewriter.template replaceOpWithNewOp( - operation, dstType, adaptor.getOperands()); - } - return success(); - } -}; - /// Converts composite std.constant operation to spv.Constant. class ConstantCompositeOpPattern final : public OpConversionPattern { @@ -223,58 +130,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts floating-point comparison operations to SPIR-V ops. -class CmpFOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating point NaN check to SPIR-V ops. This pattern requires -/// Kernel capability. 
-class CmpFOpNanKernelPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts floating point NaN check to SPIR-V ops. This pattern does not -/// require additional capability. -class CmpFOpNanNonePattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts integer compare operation on i1 type operands to SPIR-V ops. -class BoolCmpIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts integer compare operation to SPIR-V ops. -class CmpIOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - /// Converts std.return to spv.Return. class ReturnOpPattern final : public OpConversionPattern { public: @@ -304,30 +159,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts std.zexti to spv.Select if the type of source is i1 or vector of -/// i1. 
-class ZeroExtendI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ZeroExtendIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = adaptor.getOperands().front().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); - - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Location loc = op.getLoc(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.template replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); - return success(); - } -}; - /// Converts tensor.extract into loading using access chains from SPIR-V local /// variables. class TensorExtractPattern final @@ -389,124 +220,8 @@ int64_t byteCountThreshold; }; -/// Converts std.trunci to spv.Select if the type of result is i1 or vector of -/// i1. -class TruncI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(TruncateIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - if (!isBoolScalarOrVector(dstType)) - return failure(); - - Location loc = op.getLoc(); - auto srcType = adaptor.getOperands().front().getType(); - // Check if (x & 1) == 1. 
- Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create(loc, maskedSrc, mask); - - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); - return success(); - } -}; - -/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of -/// i1. -class UIToFPI1Pattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(UIToFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = adaptor.getOperands().front().getType(); - if (!isBoolScalarOrVector(srcType)) - return failure(); - - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); - Location loc = op.getLoc(); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - rewriter.template replaceOpWithNewOp( - op, dstType, adaptor.getOperands().front(), one, zero); - return success(); - } -}; - -/// Converts type-casting standard operations to SPIR-V operations. -template -class TypeCastingOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() == 1); - auto srcType = adaptor.getOperands().front().getType(); - auto dstType = - this->getTypeConverter()->convertType(operation.getResult().getType()); - if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) - return failure(); - if (dstType == srcType) { - // Due to type conversion, we are seeing the same source and target type. 
- // Then we can just erase this operation by forwarding its operand. - rewriter.replaceOp(operation, adaptor.getOperands().front()); - } else { - rewriter.template replaceOpWithNewOp(operation, dstType, - adaptor.getOperands()); - } - return success(); - } -}; - -/// Converts std.xor to SPIR-V operations. -class XOrOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector -/// of i1. -class BoolXOrOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - } // namespace -//===----------------------------------------------------------------------===// -// SignedRemIOpPattern -//===----------------------------------------------------------------------===// - -LogicalResult SignedRemIOpPattern::matchAndRewrite( - SignedRemIOp remOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value result = emulateSignedRemainder( - remOp.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], - adaptor.getOperands()[0], rewriter); - rewriter.replaceOp(remOp, result); - - return success(); -} - //===----------------------------------------------------------------------===// // ConstantOp with composite type. 
//===----------------------------------------------------------------------===// @@ -649,143 +364,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -LogicalResult -CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - switch (cmpFOp.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - // Ordered. - DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); - DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); - DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); - DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); - DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); - DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); - // Unordered. 
- DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); - DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); - DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); - DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); - DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); - DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); - -#undef DISPATCH - - default: - break; - } - return failure(); -} - -LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( - CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (cmpFOp.getPredicate() == CmpFPredicate::ORD) { - rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), - adaptor.rhs()); - return success(); - } - - if (cmpFOp.getPredicate() == CmpFPredicate::UNO) { - rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), - adaptor.rhs()); - return success(); - } - - return failure(); -} - -LogicalResult CmpFOpNanNonePattern::matchAndRewrite( - CmpFOp cmpFOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (cmpFOp.getPredicate() != CmpFPredicate::ORD && - cmpFOp.getPredicate() != CmpFPredicate::UNO) - return failure(); - - Location loc = cmpFOp.getLoc(); - - Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); - - Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); - if (cmpFOp.getPredicate() == CmpFPredicate::ORD) - replace = rewriter.create(loc, replace); - - rewriter.replaceOp(cmpFOp, replace); - return success(); -} - -//===----------------------------------------------------------------------===// -// CmpIOp -//===----------------------------------------------------------------------===// - -LogicalResult -BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type operandType = cmpIOp.lhs().getType(); - if (!isBoolScalarOrVector(operandType)) - return failure(); - - switch (cmpIOp.getPredicate()) { -#define 
DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); - DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp); - -#undef DISPATCH - default:; - } - return failure(); -} - -LogicalResult -CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Type operandType = cmpIOp.lhs().getType(); - if (isBoolScalarOrVector(operandType)) - return failure(); - - switch (cmpIOp.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - if (spirvOp::template hasTrait() && \ - operandType != this->getTypeConverter()->convertType(operandType)) { \ - return cmpIOp.emitError( \ - "bitwidth emulation is not implemented yet on unsigned op"); \ - } \ - rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - adaptor.lhs(), adaptor.rhs()); \ - return success(); - - DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); - DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); - DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); - DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); - DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); - DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); - DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); - DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); - DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); - DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp); - -#undef DISPATCH - } - return failure(); -} - //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// @@ -833,43 +411,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// XorOp 
-//===----------------------------------------------------------------------===// - -LogicalResult -XOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); - - if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) - return failure(); - - auto dstType = getTypeConverter()->convertType(xorOp.getType()); - if (!dstType) - return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, - adaptor.getOperands()); - - return success(); -} - -LogicalResult -BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(adaptor.getOperands().size() == 2); - - if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) - return failure(); - - auto dstType = getTypeConverter()->convertType(xorOp.getType()); - if (!dstType) - return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, - adaptor.getOperands()); - return success(); -} - //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -880,55 +421,12 @@ MLIRContext *context = patterns.getContext(); patterns.add< - // Unary and binary patterns - BitwiseOpPattern, - BitwiseOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - UnaryAndBinaryOpPattern, - SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, - - // Comparison patterns - BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, - // 
Constant patterns ConstantCompositeOpPattern, ConstantScalarOpPattern, - ReturnOpPattern, SelectOpPattern, SplatPattern, - - // Type cast patterns - UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern, - TypeCastingOpPattern>(typeConverter, - context); - - // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel - // capability is available. - patterns.add(typeConverter, context, - /*benefit=*/2); + ReturnOpPattern, SelectOpPattern, SplatPattern + + >(typeConverter, context); } void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp @@ -12,6 +12,8 @@ #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" +#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -39,6 +41,8 @@ SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); + populateArithmeticToSPIRVPatterns(typeConverter, patterns); + populateMathToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, patterns); diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt --- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt +++ 
b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRDialectUtils MLIRIR MLIRLinalg diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" @@ -32,12 +33,12 @@ } template -static mlir::ConstantOp +static arith::ConstantOp createConstFromIntAttribute(Operation *op, std::string attrName, Type requiredAttrType, OpBuilder &rewriter) { auto castedN = static_cast( op->getAttr(attrName).cast().getValue().getSExtValue()); - return rewriter.create( + return rewriter.create( op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); } @@ -50,9 +51,9 @@ } template -static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min, - mlir::ConstantOp max, P pred, - OpBuilder &rewriter) { +static mlir::SelectOp clampHelper(Location loc, Value arg, + arith::ConstantOp min, arith::ConstantOp max, + P pred, OpBuilder &rewriter) { auto smallerThanMin = rewriter.create(loc, pred, arg, min); auto minOrArg = rewriter.create(loc, smallerThanMin, min, arg); @@ -83,7 +84,7 @@ highIndices.push_back(rewriter.getIndexAttr(highPad)); } - Value padValue = rewriter.create(loc, padAttr); + Value padValue = rewriter.create(loc, padAttr); return linalg::PadTensorOp::createPadScalarOp( RankedTensorType::get(paddedShape, inputETy), input, padValue, @@ -101,30 +102,30 @@ // tosa::AbsOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if 
(isa(op) && elementTy.isa()) { - auto zero = - rewriter.create(loc, rewriter.getZeroAttr(elementTy)); - auto cmp = - rewriter.create(loc, CmpIPredicate::sgt, args[0], zero); - auto neg = rewriter.create(loc, zero, args[0]); + auto zero = rewriter.create( + loc, rewriter.getZeroAttr(elementTy)); + auto cmp = rewriter.create(loc, arith::CmpIPredicate::sgt, + args[0], zero); + auto neg = rewriter.create(loc, zero, args[0]); return rewriter.create(loc, cmp, args[0], neg); } // tosa::AddOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::SubOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::MulOp if (isa(op) && elementTy.isa()) { @@ -133,18 +134,18 @@ "Cannot have shift value for float"); return nullptr; } - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); } // tosa::DivOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ReciprocalOp if (isa(op) && elementTy.isa()) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - return rewriter.create(loc, resultTypes, one, args[0]); + rewriter.create(loc, FloatAttr::get(elementTy, 1)); + return rewriter.create(loc, resultTypes, one, args[0]); } if (isa(op) && elementTy.isa()) { @@ -154,12 +155,12 @@ op->getAttr("shift").cast().getValue().getSExtValue(); if (shift > 0) { auto shiftConst = - rewriter.create(loc, shift, /*bitwidth=*/8); + rewriter.create(loc, shift, /*bitwidth=*/8); if (!a.getType().isInteger(32)) - a = 
rewriter.create(loc, rewriter.getI32Type(), a); + a = rewriter.create(loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) - b = rewriter.create(loc, rewriter.getI32Type(), b); + b = rewriter.create(loc, rewriter.getI32Type(), b); auto result = rewriter.create( loc, rewriter.getI32Type(), a, b, shiftConst, @@ -168,7 +169,7 @@ if (elementTy.isInteger(32)) return result; - return rewriter.create(loc, elementTy, result); + return rewriter.create(loc, elementTy, result); } int aWidth = a.getType().getIntOrFloatBitWidth(); @@ -176,22 +177,22 @@ int cWidth = resultTypes[0].getIntOrFloatBitWidth(); if (aWidth < cWidth) - a = rewriter.create(loc, resultTypes[0], a); + a = rewriter.create(loc, resultTypes[0], a); if (bWidth < cWidth) - b = rewriter.create(loc, resultTypes[0], b); + b = rewriter.create(loc, resultTypes[0], b); - return rewriter.create(loc, resultTypes, a, b); + return rewriter.create(loc, resultTypes, a, b); } // tosa::NegateOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); if (isa(op) && elementTy.isa() && !cast(op).quantization_info()) { auto constant = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); - return rewriter.create(loc, resultTypes, constant, args[0]); + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + return rewriter.create(loc, resultTypes, constant, args[0]); } if (isa(op) && elementTy.isa() && @@ -220,62 +221,59 @@ } Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); - Value zpAddValue = rewriter.create( + Value zpAddValue = rewriter.create( loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue - auto ext = rewriter.create(loc, intermediateType, args[0]); - auto sub = rewriter.create(loc, zpAddValue, ext); + auto ext = rewriter.create(loc, intermediateType, args[0]); + auto sub = rewriter.create(loc, zpAddValue, ext); // Clamp to 
the negation range. - auto min = rewriter.create( - loc, rewriter.getIntegerAttr( - intermediateType, - APInt::getSignedMinValue(inputBitWidth).getSExtValue())); - auto max = rewriter.create( - loc, rewriter.getIntegerAttr( - intermediateType, - APInt::getSignedMaxValue(inputBitWidth).getSExtValue())); - auto clamp = clampHelper(loc, sub, min, max, - CmpIPredicate::slt, rewriter); + auto min = rewriter.create( + loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(), + intermediateType); + auto max = rewriter.create( + loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(), + intermediateType); + auto clamp = clampHelper( + loc, sub, min, max, arith::CmpIPredicate::slt, rewriter); // Truncate to the final value. - return rewriter.create(loc, elementTy, clamp); + return rewriter.create(loc, elementTy, clamp); } // tosa::BitwiseAndOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::BitwiseOrOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::BitwiseNotOp if (isa(op) && elementTy.isa()) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); - auto allOnes = rewriter.create(loc, allOnesAttr); - return rewriter.create(loc, resultTypes, args[0], allOnes); + auto allOnes = rewriter.create(loc, allOnesAttr); + return rewriter.create(loc, resultTypes, args[0], allOnes); } // tosa::BitwiseXOrOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalLeftShiftOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalRightShiftOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, 
resultTypes, args); // tosa::ArithmeticRightShiftOp if (isa(op) && elementTy.isa()) { - auto result = - rewriter.create(loc, resultTypes, args); + auto result = rewriter.create(loc, resultTypes, args); auto round = op->getAttr("round").cast().getValue(); if (!round) { return result; @@ -283,40 +281,40 @@ Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); auto one = - rewriter.create(loc, IntegerAttr::get(elementTy, 1)); + rewriter.create(loc, IntegerAttr::get(elementTy, 1)); auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto i1one = - rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); + rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); // Checking that input2 != 0 - auto shiftValueGreaterThanZero = - rewriter.create(loc, CmpIPredicate::sgt, args[1], zero); + auto shiftValueGreaterThanZero = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[1], zero); // Checking for the last bit of input1 to be 1 auto subtract = - rewriter.create(loc, resultTypes, args[1], one); - auto shifted = rewriter - .create(loc, resultTypes, - args[0], subtract) - ->getResults(); + rewriter.create(loc, resultTypes, args[1], one); + auto shifted = + rewriter.create(loc, resultTypes, args[0], subtract) + ->getResults(); auto truncated = - rewriter.create(loc, i1Ty, shifted, mlir::None); - auto isInputOdd = rewriter.create(loc, i1Ty, truncated, i1one); + rewriter.create(loc, i1Ty, shifted, mlir::None); + auto isInputOdd = + rewriter.create(loc, i1Ty, truncated, i1one); - auto shouldRound = rewriter.create( + auto shouldRound = rewriter.create( loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); auto extended = - rewriter.create(loc, resultTypes, shouldRound); - return rewriter.create(loc, resultTypes, result, extended); + rewriter.create(loc, resultTypes, shouldRound); + return rewriter.create(loc, resultTypes, result, extended); } // tosa::ClzOp if (isa(op) && elementTy.isa()) { int bitWidth = 
elementTy.getIntOrFloatBitWidth(); auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); - auto leadingZeros = rewriter.create( + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + auto leadingZeros = rewriter.create( loc, IntegerAttr::get(elementTy, bitWidth)); SmallVector operands = {args[0], leadingZeros, zero}; @@ -332,8 +330,8 @@ Value input = before->getArgument(0); Value zero = before->getArgument(2); - Value inputLargerThanZero = - rewriter.create(loc, CmpIPredicate::ne, input, zero); + Value inputLargerThanZero = rewriter.create( + loc, arith::CmpIPredicate::ne, input, zero); rewriter.create(loc, inputLargerThanZero, before->getArguments()); } @@ -344,12 +342,12 @@ Value input = after->getArgument(0); Value leadingZeros = after->getArgument(1); - auto one = rewriter.create( + auto one = rewriter.create( loc, IntegerAttr::get(elementTy, 1)); - auto shifted = rewriter.create( - loc, resultTypes, input, one); + auto shifted = + rewriter.create(loc, resultTypes, input, one); auto leadingZerosMinusOne = - rewriter.create(loc, resultTypes, leadingZeros, one); + rewriter.create(loc, resultTypes, leadingZeros, one); rewriter.create( loc, @@ -362,22 +360,22 @@ // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalNot if (isa(op) && elementTy.isInteger(1)) { - auto one = rewriter.create( + auto one = rewriter.create( loc, rewriter.getIntegerAttr(elementTy, 1)); - return rewriter.create(loc, resultTypes, args[0], one); + return rewriter.create(loc, resultTypes, args[0], one); } // tosa::LogicalOr if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::LogicalXor if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::PowOp if (isa(op) && 
elementTy.isa()) @@ -401,30 +399,30 @@ // tosa::GreaterOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OGT, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OGT, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::sgt, args[0], - args[1]); + return rewriter.create(loc, arith::CmpIPredicate::sgt, + args[0], args[1]); // tosa::GreaterEqualOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OGE, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OGE, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::sge, args[0], - args[1]); + return rewriter.create(loc, arith::CmpIPredicate::sge, + args[0], args[1]); // tosa::EqualOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, CmpFPredicate::OEQ, args[0], - args[1]); + return rewriter.create(loc, arith::CmpFPredicate::OEQ, + args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) - return rewriter.create(loc, CmpIPredicate::eq, args[0], - args[1]); + return rewriter.create(loc, arith::CmpIPredicate::eq, + args[0], args[1]); // tosa::SelectOp if (isa(op)) { @@ -435,46 +433,46 @@ // tosa::MaximumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OGT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OGT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { - auto predicate = rewriter.create(loc, CmpIPredicate::sgt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::MinimumOp if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OLT, - args[0], args[1]); + auto predicate = 
rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { - auto predicate = rewriter.create(loc, CmpIPredicate::slt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::CeilOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::FloorOp if (isa(op) && elementTy.isa()) - return rewriter.create(loc, resultTypes, args); + return rewriter.create(loc, resultTypes, args); // tosa::ClampOp if (isa(op) && elementTy.isa()) { - auto min = rewriter.create(loc, elementTy, - op->getAttr("min_fp")); - auto max = rewriter.create(loc, elementTy, - op->getAttr("max_fp")); - return clampHelper(loc, args[0], min, max, CmpFPredicate::OLT, - rewriter); + auto min = rewriter.create(loc, elementTy, + op->getAttr("min_fp")); + auto max = rewriter.create(loc, elementTy, + op->getAttr("max_fp")); + return clampHelper(loc, args[0], min, max, + arith::CmpFPredicate::OLT, rewriter); } if (isa(op) && elementTy.isa()) { @@ -498,41 +496,41 @@ .getSExtValue()); } - auto minVal = - rewriter.create(loc, min, intTy.getIntOrFloatBitWidth()); - auto maxVal = - rewriter.create(loc, max, intTy.getIntOrFloatBitWidth()); - return clampHelper(loc, args[0], minVal, maxVal, - CmpIPredicate::slt, rewriter); + auto minVal = rewriter.create( + loc, min, intTy.getIntOrFloatBitWidth()); + auto maxVal = rewriter.create( + loc, max, intTy.getIntOrFloatBitWidth()); + return clampHelper(loc, args[0], minVal, maxVal, + arith::CmpIPredicate::slt, rewriter); } // tosa::ReluNOp if (isa(op) && elementTy.isa()) { auto zero = - rewriter.create(loc, FloatAttr::get(elementTy, 0)); - auto n = rewriter.create(loc, elementTy, - op->getAttr("max_fp")); - return clampHelper(loc, args[0], zero, n, 
CmpFPredicate::OLT, - rewriter); + rewriter.create(loc, FloatAttr::get(elementTy, 0)); + auto n = rewriter.create(loc, elementTy, + op->getAttr("max_fp")); + return clampHelper(loc, args[0], zero, n, + arith::CmpFPredicate::OLT, rewriter); } if (isa(op) && elementTy.isa()) { auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); + rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto n = createConstFromIntAttribute(op, "max_int", elementTy, rewriter); - return clampHelper(loc, args[0], zero, n, CmpIPredicate::slt, - rewriter); + return clampHelper(loc, args[0], zero, n, + arith::CmpIPredicate::slt, rewriter); } // tosa::SigmoidOp if (isa(op) && elementTy.isa()) { auto one = - rewriter.create(loc, FloatAttr::get(elementTy, 1)); - auto negate = rewriter.create(loc, resultTypes, args[0]); + rewriter.create(loc, FloatAttr::get(elementTy, 1)); + auto negate = rewriter.create(loc, resultTypes, args[0]); auto exp = rewriter.create(loc, resultTypes, negate); - auto added = rewriter.create(loc, resultTypes, exp, one); - return rewriter.create(loc, resultTypes, one, added); + auto added = rewriter.create(loc, resultTypes, exp, one); + return rewriter.create(loc, resultTypes, one, added); } // tosa::CastOp @@ -546,92 +544,92 @@ return args.front(); if (srcTy.isa() && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, mlir::None); + return rewriter.create(loc, resultTypes, args, mlir::None); if (srcTy.isa() && dstTy.isa() && !bitExtend) - return rewriter.create(loc, resultTypes, args, + return rewriter.create(loc, resultTypes, args, mlir::None); // 1-bit integers need to be treated as signless. 
- if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, - mlir::None); + if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); if (srcTy.isInteger(1) && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, - mlir::None); + return rewriter.create(loc, resultTypes, args, + mlir::None); // All other si-to-fp conversions should be handled by SIToFP. - if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create(loc, resultTypes, args, - mlir::None); + if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) + return rewriter.create(loc, resultTypes, args, + mlir::None); // Casting to boolean, floats need to only be checked as not-equal to zero. if (srcTy.isa() && dstTy.isInteger(1)) { - Value zero = - rewriter.create(loc, rewriter.getFloatAttr(srcTy, 0.0)); - return rewriter.create(loc, CmpFPredicate::UNE, - args.front(), zero); + Value zero = rewriter.create( + loc, rewriter.getFloatAttr(srcTy, 0.0)); + return rewriter.create(loc, arith::CmpFPredicate::UNE, + args.front(), zero); } - if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto zero = - rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto half = - rewriter.create(loc, rewriter.getF32FloatAttr(0.5f)); + if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { + auto zero = rewriter.create( + loc, rewriter.getF32FloatAttr(0.0f)); + auto half = rewriter.create( + loc, rewriter.getF32FloatAttr(0.5f)); - auto intMin = rewriter.create( + auto intMin = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto intMax = rewriter.create( + auto intMax = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto added = rewriter.create(loc, args[0], half); - auto subbed = 
rewriter.create(loc, args[0], half); - auto negative = - rewriter.create(loc, CmpFPredicate::OLT, args[0], zero); + auto added = rewriter.create(loc, args[0], half); + auto subbed = rewriter.create(loc, args[0], half); + auto negative = rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], zero); auto rounded = rewriter.create(loc, negative, subbed, added); - auto clamped = clampHelper(loc, rounded, intMin, intMax, - CmpFPredicate::OLT, rewriter); + auto clamped = clampHelper( + loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter); - return rewriter.create(loc, dstTy, clamped); + return rewriter.create(loc, dstTy, clamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. if (srcTy.isa() && dstTy.isInteger(1)) { - Value zero = - rewriter.create(loc, 0, srcTy.getIntOrFloatBitWidth()); - return rewriter.create(loc, CmpIPredicate::ne, args.front(), - zero); + Value zero = rewriter.create( + loc, 0, srcTy.getIntOrFloatBitWidth()); + return rewriter.create(loc, arith::CmpIPredicate::ne, + args.front(), zero); } if (srcTy.isa() && dstTy.isa() && bitExtend) - return rewriter.create(loc, resultTypes, args, - mlir::None); + return rewriter.create(loc, resultTypes, args, + mlir::None); if (srcTy.isa() && dstTy.isa() && !bitExtend) { - auto intMin = rewriter.create( + auto intMin = rewriter.create( loc, APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue(), srcTy.getIntOrFloatBitWidth()); - auto intMax = rewriter.create( + auto intMax = rewriter.create( loc, APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue(), srcTy.getIntOrFloatBitWidth()); - auto clamped = clampHelper(loc, args[0], intMin, intMax, - CmpIPredicate::slt, rewriter); - return rewriter.create(loc, dstTy, clamped); + auto clamped = clampHelper( + loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter); + return rewriter.create(loc, dstTy, clamped); } } @@ -814,50 +812,50 @@ PatternRewriter &rewriter) { Location 
loc = op->getLoc(); if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - return rewriter.create(loc, args); + return rewriter.create(loc, args); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OLT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OLT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpIPredicate::slt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpFPredicate::OGT, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpFPredicate::OGT, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isa()) { - auto predicate = rewriter.create(loc, CmpIPredicate::sgt, - args[0], args[1]); + auto predicate = rewriter.create( + loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return rewriter.create(loc, args); if (isa(op) && elementTy.isInteger(1)) - return rewriter.create(loc, args); + return rewriter.create(loc, args); return {}; } @@ -893,7 +891,7 @@ return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); - auto fillValue = rewriter.create(loc, fillValueAttr); + auto fillValue = rewriter.create(loc, fillValueAttr); auto filledTensor = 
rewriter.create(loc, fillValue, initTensor).result(); @@ -1014,7 +1012,8 @@ weightShape[3], weightShape[0]}; auto weightPermAttr = DenseIntElementsAttr::get( RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm); - Value weightPermValue = rewriter.create(loc, weightPermAttr); + Value weightPermValue = + rewriter.create(loc, weightPermAttr); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); weight = rewriter.create(loc, newWeightTy, weight, @@ -1023,7 +1022,7 @@ Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( loc, resultTy.getShape(), resultETy); - Value zero = rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); @@ -1057,8 +1056,8 @@ auto kZp = rewriter.getI32IntegerAttr( quantizationInfo.weight_zp().getValue().getSExtValue()); - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto iZpVal = rewriter.create(loc, iZp); + auto kZpVal = rewriter.create(loc, kZp); Value conv = rewriter .create( @@ -1073,8 +1072,8 @@ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1095,8 +1094,8 @@ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1205,7 +1204,7 @@ Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( loc, linalgConvTy.getShape(), resultETy); - Value zero = 
rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); @@ -1226,15 +1225,15 @@ getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); rewriter.replaceOp(op, result); } else { - auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto iZpVal = rewriter.create(loc, iZp); + auto kZpVal = rewriter.create(loc, kZp); Value conv = rewriter .create( @@ -1250,8 +1249,8 @@ getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1343,7 +1342,7 @@ auto outputTy = op.getType().cast(); auto outputElementTy = outputTy.getElementType(); auto zeroAttr = rewriter.getZeroAttr(outputElementTy); - Value zero = rewriter.create(loc, zeroAttr); + Value zero = rewriter.create(loc, zeroAttr); auto initTensor = rewriter.create( loc, outputTy.getShape(), outputTy.getElementType()); Value zeroTensor = @@ -1356,10 +1355,10 @@ } auto quantizationInfo = op.quantization_info().getValue(); - auto aZp = rewriter.create( + auto aZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.a_zp().getValue().getSExtValue())); - auto bZp = rewriter.create( + auto bZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.b_zp().getValue().getSExtValue())); rewriter.replaceOpWithNewOp( @@ -1404,14 +1403,15 @@ // When quantized, the input elemeny type is not the same as the output Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy); - Value zero = 
rewriter.create(loc, resultZeroAttr); + Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); SmallVector permutation{1, 0}; auto permutationAttr = DenseIntElementsAttr::get( RankedTensorType::get({2}, rewriter.getI64Type()), permutation); - Value permutationValue = rewriter.create(loc, permutationAttr); + Value permutationValue = + rewriter.create(loc, permutationAttr); SmallVector newWeightShape{weightShape[1], weightShape[0]}; Type newWeightTy = @@ -1439,8 +1439,8 @@ indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1449,10 +1449,10 @@ } auto quantizationInfo = op.quantization_info().getValue(); - auto inputZp = rewriter.create( + auto inputZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.input_zp().getValue().getSExtValue())); - auto outputZp = rewriter.create( + auto outputZp = rewriter.create( loc, rewriter.getI32IntegerAttr( quantizationInfo.weight_zp().getValue().getSExtValue())); Value matmul = @@ -1469,8 +1469,8 @@ indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value added = - nestedBuilder.create(loc, args[0], args[1]); + Value added = nestedBuilder.create( + loc, args[0], args[1]); nestedBuilder.create(nestedLoc, added); }) .getResult(0); @@ -1674,7 +1674,7 @@ Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { - multiplierConstant = rewriter.create( + multiplierConstant = rewriter.create( loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector multiplierExprs{ @@ -1682,7 +1682,7 @@ auto multiplierType = 
RankedTensorType::get({static_cast(multiplierValues.size())}, rewriter.getI32Type()); - genericInputs.push_back(rewriter.create( + genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, @@ -1697,7 +1697,7 @@ Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { - shiftConstant = rewriter.create( + shiftConstant = rewriter.create( loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector shiftExprs = { @@ -1705,7 +1705,7 @@ auto shiftType = RankedTensorType::get({static_cast(shiftValues.size())}, rewriter.getIntegerType(8)); - genericInputs.push_back(rewriter.create( + genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(shiftType, shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, @@ -1753,22 +1753,24 @@ valueTy.getIntOrFloatBitWidth()), value) .getResult(0); - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } else { - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } } - value = nestedBuilder.create(nestedLoc, value, inputZp); + value = + nestedBuilder.create(nestedLoc, value, inputZp); value = nestedBuilder.create( loc, nestedBuilder.getI32Type(), value, multiplier, shift, nestedBuilder.getBoolAttr(doubleRound)); // Move to the new zero-point. - value = nestedBuilder.create(nestedLoc, value, outputZp); + value = + nestedBuilder.create(nestedLoc, value, outputZp); // Saturate to the output size. 
IntegerType outIntType = @@ -1784,19 +1786,17 @@ intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); } - auto intMinVal = nestedBuilder.create( - loc, - nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMin)); - auto intMaxVal = nestedBuilder.create( - loc, - nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMax)); + auto intMinVal = nestedBuilder.create( + loc, nestedBuilder.getI32IntegerAttr(intMin)); + auto intMaxVal = nestedBuilder.create( + loc, nestedBuilder.getI32IntegerAttr(intMax)); - value = - clampHelper(nestedLoc, value, intMinVal, intMaxVal, - CmpIPredicate::slt, nestedBuilder); + value = clampHelper( + nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt, + nestedBuilder); if (outIntType.getWidth() < 32) { - value = nestedBuilder.create( + value = nestedBuilder.create( nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), value); @@ -1859,37 +1859,39 @@ Value x = rewriter.create(loc, 2); Value channel = rewriter.create(loc, 3); - auto hwMin = - rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto hMax = rewriter.create( + auto hwMin = rewriter.create( + loc, rewriter.getI32IntegerAttr(0)); + auto hMax = rewriter.create( loc, rewriter.getI32IntegerAttr(imageH - 1)); - auto wMax = rewriter.create( + auto wMax = rewriter.create( loc, rewriter.getI32IntegerAttr(imageW - 1)); - Value inY = rewriter.create(loc, rewriter.getI32Type(), y); - Value inX = rewriter.create(loc, rewriter.getI32Type(), x); + Value inY = + rewriter.create(loc, rewriter.getI32Type(), y); + Value inX = + rewriter.create(loc, rewriter.getI32Type(), x); int32_t shift = op.shift(); bool floatingPointMode = shift == 0; Value yStride, xStride, yOffset, xOffset; if (floatingPointMode) { - yStride = rewriter.create(loc, op.stride_fp()[0]); - xStride = rewriter.create(loc, op.stride_fp()[1]); - yOffset = rewriter.create(loc, op.offset_fp()[0]); - xOffset = rewriter.create(loc, op.offset_fp()[1]); + yStride = rewriter.create(loc, 
op.stride_fp()[0]); + xStride = rewriter.create(loc, op.stride_fp()[1]); + yOffset = rewriter.create(loc, op.offset_fp()[0]); + xOffset = rewriter.create(loc, op.offset_fp()[1]); } else { SmallVector stride, offset; getValuesFromIntArrayAttribute(op.stride(), stride); getValuesFromIntArrayAttribute(op.offset(), offset); - yStride = rewriter.create( + yStride = rewriter.create( loc, rewriter.getI32IntegerAttr(stride[0])); - xStride = rewriter.create( + xStride = rewriter.create( loc, rewriter.getI32IntegerAttr(stride[1])); - yOffset = rewriter.create( + yOffset = rewriter.create( loc, rewriter.getI32IntegerAttr(offset[0])); - xOffset = rewriter.create( + xOffset = rewriter.create( loc, rewriter.getI32IntegerAttr(offset[1])); } @@ -1899,85 +1901,89 @@ // dx = x - ix Value ix, iy, dx, dy; if (floatingPointMode) { - Value y = rewriter.create(loc, rewriter.getF32Type(), inY); - Value x = rewriter.create(loc, rewriter.getF32Type(), inX); + Value y = + rewriter.create(loc, rewriter.getF32Type(), inY); + Value x = + rewriter.create(loc, rewriter.getF32Type(), inX); - y = rewriter.create(loc, y, yStride); - x = rewriter.create(loc, x, xStride); + y = rewriter.create(loc, y, yStride); + x = rewriter.create(loc, x, xStride); - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); + y = rewriter.create(loc, y, yOffset); + x = rewriter.create(loc, x, xOffset); - iy = rewriter.create(loc, y); - ix = rewriter.create(loc, x); + iy = rewriter.create(loc, y); + ix = rewriter.create(loc, x); - dy = rewriter.create(loc, y, iy); - dx = rewriter.create(loc, x, ix); + dy = rewriter.create(loc, y, iy); + dx = rewriter.create(loc, x, ix); - iy = rewriter.create(loc, rewriter.getI32Type(), iy); - ix = rewriter.create(loc, rewriter.getI32Type(), ix); + iy = rewriter.create(loc, rewriter.getI32Type(), iy); + ix = rewriter.create(loc, rewriter.getI32Type(), ix); } else { - Value shiftVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(shift)); + Value shiftVal = 
rewriter.create( + loc, rewriter.getI32IntegerAttr(shift)); - Value y = rewriter.create(loc, inY, yStride); - Value x = rewriter.create(loc, inX, xStride); + Value y = rewriter.create(loc, inY, yStride); + Value x = rewriter.create(loc, inX, xStride); - y = rewriter.create(loc, y, yOffset); - x = rewriter.create(loc, x, xOffset); + y = rewriter.create(loc, y, yOffset); + x = rewriter.create(loc, x, xOffset); - iy = rewriter.create(loc, y, shiftVal); - ix = rewriter.create(loc, x, shiftVal); + iy = rewriter.create(loc, y, shiftVal); + ix = rewriter.create(loc, x, shiftVal); - Value yTrunc = rewriter.create(loc, iy, shiftVal); - Value xTrunc = rewriter.create(loc, ix, shiftVal); + Value yTrunc = rewriter.create(loc, iy, shiftVal); + Value xTrunc = rewriter.create(loc, ix, shiftVal); - dy = rewriter.create(loc, y, yTrunc); - dx = rewriter.create(loc, x, xTrunc); + dy = rewriter.create(loc, y, yTrunc); + dx = rewriter.create(loc, x, xTrunc); } if (op.mode() == "NEAREST_NEIGHBOR") { Value yPred, xPred; // Round the index position towards the closest pixel location. 
if (floatingPointMode) { - auto halfVal = - rewriter.create(loc, rewriter.getF32FloatAttr(0.5f)); - yPred = rewriter.create(loc, CmpFPredicate::OGE, dy, - halfVal); - xPred = rewriter.create(loc, CmpFPredicate::OGE, dx, - halfVal); + auto halfVal = rewriter.create( + loc, rewriter.getF32FloatAttr(0.5f)); + yPred = rewriter.create(loc, arith::CmpFPredicate::OGE, + dy, halfVal); + xPred = rewriter.create(loc, arith::CmpFPredicate::OGE, + dx, halfVal); } else { - auto halfVal = rewriter.create( + auto halfVal = rewriter.create( loc, rewriter.getI32IntegerAttr(1 << (shift - 1))); - yPred = rewriter.create(loc, CmpIPredicate::sge, dy, - halfVal); - xPred = rewriter.create(loc, CmpIPredicate::sge, dx, - halfVal); + yPred = rewriter.create(loc, arith::CmpIPredicate::sge, + dy, halfVal); + xPred = rewriter.create(loc, arith::CmpIPredicate::sge, + dx, halfVal); } - auto zeroVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto oneVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); + auto zeroVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(0)); + auto oneVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); auto yOffset = rewriter.create(loc, yPred, oneVal, zeroVal); auto xOffset = rewriter.create(loc, xPred, oneVal, zeroVal); - iy = rewriter.create(loc, iy, yOffset); - ix = rewriter.create(loc, ix, xOffset); + iy = rewriter.create(loc, iy, yOffset); + ix = rewriter.create(loc, ix, xOffset); // Clamp the to be within the bounds of the input image. - iy = clampHelper(loc, iy, hwMin, hMax, CmpIPredicate::slt, - rewriter); - ix = clampHelper(loc, ix, hwMin, wMax, CmpIPredicate::slt, - rewriter); + iy = clampHelper(loc, iy, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + ix = clampHelper(loc, ix, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); // Read the value from the input array. 
- iy = rewriter.create(loc, rewriter.getIndexType(), iy); - ix = rewriter.create(loc, rewriter.getIndexType(), ix); + iy = rewriter.create(loc, rewriter.getIndexType(), + iy); + ix = rewriter.create(loc, rewriter.getIndexType(), + ix); Value result = rewriter.create( loc, input, ValueRange{batch, iy, ix, channel}); @@ -1991,25 +1997,29 @@ Value y0 = iy; Value x0 = ix; - auto oneVal = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); - Value y1 = rewriter.create(loc, y0, oneVal); - Value x1 = rewriter.create(loc, x0, oneVal); - - y0 = clampHelper(loc, y0, hwMin, hMax, CmpIPredicate::slt, - rewriter); - y1 = clampHelper(loc, y1, hwMin, hMax, CmpIPredicate::slt, - rewriter); - - x0 = clampHelper(loc, x0, hwMin, wMax, CmpIPredicate::slt, - rewriter); - x1 = clampHelper(loc, x1, hwMin, wMax, CmpIPredicate::slt, - rewriter); - - y0 = rewriter.create(loc, rewriter.getIndexType(), y0); - y1 = rewriter.create(loc, rewriter.getIndexType(), y1); - x0 = rewriter.create(loc, rewriter.getIndexType(), x0); - x1 = rewriter.create(loc, rewriter.getIndexType(), x1); + auto oneVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); + Value y1 = rewriter.create(loc, y0, oneVal); + Value x1 = rewriter.create(loc, x0, oneVal); + + y0 = clampHelper(loc, y0, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + y1 = clampHelper(loc, y1, hwMin, hMax, + arith::CmpIPredicate::slt, rewriter); + + x0 = clampHelper(loc, x0, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); + x1 = clampHelper(loc, x1, hwMin, wMax, + arith::CmpIPredicate::slt, rewriter); + + y0 = rewriter.create(loc, rewriter.getIndexType(), + y0); + y1 = rewriter.create(loc, rewriter.getIndexType(), + y1); + x0 = rewriter.create(loc, rewriter.getIndexType(), + x0); + x1 = rewriter.create(loc, rewriter.getIndexType(), + x1); Value y0x0 = rewriter.create( loc, input, ValueRange{batch, y0, x0, channel}); @@ -2021,56 +2031,58 @@ loc, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { - auto oneVal = 
- rewriter.create(loc, rewriter.getF32FloatAttr(1.f)); + auto oneVal = rewriter.create( + loc, rewriter.getF32FloatAttr(1.f)); Value rightPart = dx; - Value leftPart = rewriter.create(loc, oneVal, dx); + Value leftPart = rewriter.create(loc, oneVal, dx); - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + y0x0 = rewriter.create(loc, y0x0, leftPart); + y0x1 = rewriter.create(loc, y0x1, rightPart); + Value topAcc = rewriter.create(loc, y0x0, y0x1); - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + y1x0 = rewriter.create(loc, y1x0, leftPart); + y1x1 = rewriter.create(loc, y1x1, rightPart); + Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; - Value topPart = rewriter.create(loc, oneVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value topPart = rewriter.create(loc, oneVal, dy); + topAcc = rewriter.create(loc, topAcc, topPart); + bottomAcc = + rewriter.create(loc, bottomAcc, bottomPart); + Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); } else { - y0x0 = rewriter.create(loc, resultElementTy, y0x0); - y0x1 = rewriter.create(loc, resultElementTy, y0x1); - y1x0 = rewriter.create(loc, resultElementTy, y1x0); - y1x1 = rewriter.create(loc, resultElementTy, y1x1); + y0x0 = rewriter.create(loc, resultElementTy, y0x0); + y0x1 = rewriter.create(loc, resultElementTy, y0x1); + y1x0 = rewriter.create(loc, resultElementTy, y1x0); + y1x1 = rewriter.create(loc, resultElementTy, y1x1); if (resultElementTy.getIntOrFloatBitWidth() > 32) { - dx = rewriter.create(loc, resultElementTy, dx); - dy = rewriter.create(loc, resultElementTy, dy); + dx = rewriter.create(loc, resultElementTy, dx); 
+ dy = rewriter.create(loc, resultElementTy, dy); } - auto unitVal = rewriter.create( + auto unitVal = rewriter.create( loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift)); Value rightPart = dx; - Value leftPart = rewriter.create(loc, unitVal, dx); + Value leftPart = rewriter.create(loc, unitVal, dx); - y0x0 = rewriter.create(loc, y0x0, leftPart); - y0x1 = rewriter.create(loc, y0x1, rightPart); - Value topAcc = rewriter.create(loc, y0x0, y0x1); + y0x0 = rewriter.create(loc, y0x0, leftPart); + y0x1 = rewriter.create(loc, y0x1, rightPart); + Value topAcc = rewriter.create(loc, y0x0, y0x1); - y1x0 = rewriter.create(loc, y1x0, leftPart); - y1x1 = rewriter.create(loc, y1x1, rightPart); - Value bottomAcc = rewriter.create(loc, y1x0, y1x1); + y1x0 = rewriter.create(loc, y1x0, leftPart); + y1x1 = rewriter.create(loc, y1x1, rightPart); + Value bottomAcc = rewriter.create(loc, y1x0, y1x1); Value bottomPart = dy; - Value topPart = rewriter.create(loc, unitVal, dy); - topAcc = rewriter.create(loc, topAcc, topPart); - bottomAcc = rewriter.create(loc, bottomAcc, bottomPart); - Value result = rewriter.create(loc, topAcc, bottomAcc); + Value topPart = rewriter.create(loc, unitVal, dy); + topAcc = rewriter.create(loc, topAcc, topPart); + bottomAcc = + rewriter.create(loc, bottomAcc, bottomPart); + Value result = rewriter.create(loc, topAcc, bottomAcc); rewriter.create(loc, result); return success(); @@ -2125,12 +2137,12 @@ Location loc = op.getLoc(); int axis = op.axis(); Value axisValue = - rewriter.create(loc, rewriter.getIndexAttr(axis)); + rewriter.create(loc, rewriter.getIndexAttr(axis)); int rank = resultType.getRank(); SmallVector offsets, sizes, strides; sizes.reserve(rank); - strides.resize(rank, rewriter.create(loc, 1)); - offsets.resize(rank, rewriter.create(loc, 0)); + strides.resize(rank, rewriter.create(loc, 1)); + offsets.resize(rank, rewriter.create(loc, 0)); for (int i = 0; i < rank; ++i) { sizes.push_back( @@ -2140,14 +2152,14 @@ Value resultDimSize = 
sizes[axis]; for (auto arg : adaptor.getOperands().drop_front()) { auto size = rewriter.create(loc, arg, axisValue); - resultDimSize = rewriter.create(loc, resultDimSize, size); + resultDimSize = rewriter.create(loc, resultDimSize, size); } sizes[axis] = resultDimSize; Value init = rewriter.create( loc, resultType.getShape(), resultType.getElementType()); - Value zeroVal = rewriter.create( + Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(resultType.getElementType())); Value result = rewriter.create(loc, zeroVal, init).getResult(0); @@ -2156,7 +2168,8 @@ sizes[axis] = rewriter.create(loc, arg, axisValue); result = rewriter.create(loc, arg, result, offsets, sizes, strides); - offsets[axis] = rewriter.create(loc, offsets[axis], sizes[axis]); + offsets[axis] = + rewriter.create(loc, offsets[axis], sizes[axis]); } rewriter.replaceOp(op, result); return success(); @@ -2202,10 +2215,11 @@ auto index = rewriter.create(nestedLoc, i).getResult(); if (i == axis) { - auto one = rewriter.create(nestedLoc, 1); + auto one = rewriter.create(nestedLoc, 1); auto sizeMinusOne = - rewriter.create(nestedLoc, axisDimSize, one); - index = rewriter.create(nestedLoc, sizeMinusOne, index); + rewriter.create(nestedLoc, axisDimSize, one); + index = rewriter.create(nestedLoc, sizeMinusOne, + index); } indices.push_back(index); @@ -2319,9 +2333,10 @@ "tosa.pad to linalg lowering encountered an unknown element type"); } - Value lowIndex = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value lowIndex = + rewriter.create(loc, rewriter.getIndexAttr(0)); Value highIndex = - rewriter.create(loc, rewriter.getIndexAttr(1)); + rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector lowValues; SmallVector highValues; @@ -2330,22 +2345,22 @@ highValues.reserve(rank); for (int i = 0; i < rank; i++) { - Value inputIndex = rewriter.createOrFold(loc, i); + Value inputIndex = rewriter.createOrFold(loc, i); Value lowVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, 
lowIndex})); Value highVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, highIndex})); - lowVal = rewriter.createOrFold(loc, rewriter.getIndexType(), - lowVal); - highVal = rewriter.createOrFold(loc, rewriter.getIndexType(), - highVal); + lowVal = rewriter.createOrFold( + loc, rewriter.getIndexType(), lowVal); + highVal = rewriter.createOrFold( + loc, rewriter.getIndexType(), highVal); lowValues.push_back(lowVal); highValues.push_back(highVal); } - Value constant = rewriter.create(loc, constantAttr); + Value constant = rewriter.create(loc, constantAttr); auto newPadOp = linalg::PadTensorOp::createPadScalarOp( padOp.getType(), input, constant, lowValues, highValues, @@ -2400,7 +2415,7 @@ .create(loc, ArrayRef({}), resultTy.getShape(), outElementTy) .result(); - auto fillValueIdx = rewriter.create( + auto fillValueIdx = rewriter.create( loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = rewriter.create(loc, fillValueIdx, initTensorIdx) @@ -2419,7 +2434,8 @@ return rewriter.notifyMatchFailure( argmaxOp, "unsupported tosa.argmax element type"); - auto fillValueMax = rewriter.create(loc, fillValueMaxAttr); + auto fillValueMax = + rewriter.create(loc, fillValueMaxAttr); auto filledTensorMax = rewriter.create(loc, fillValueMax, initTensorMax) .result(); @@ -2449,17 +2465,17 @@ auto oldIndex = blockArgs[1]; auto oldValue = blockArgs[2]; - Value newIndex = rewriter.create( + Value newIndex = rewriter.create( nestedLoc, oldIndex.getType(), rewriter.create(loc, axis)); Value predicate; if (inElementTy.isa()) { - predicate = rewriter.create( - nestedLoc, CmpFPredicate::OGT, newValue, oldValue); + predicate = rewriter.create( + nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else if (inElementTy.isa()) { - predicate = rewriter.create( - nestedLoc, CmpIPredicate::sgt, newValue, oldValue); + predicate = rewriter.create( + nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { didEncounterError = true; return; @@ 
-2523,7 +2539,7 @@ [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; auto index0 = rewriter.create(loc, 0); - Value index1 = rewriter.create( + Value index1 = rewriter.create( loc, rewriter.getIndexType(), indexValue); auto index2 = rewriter.create(loc, 2); Value extract = rewriter.create( @@ -2584,11 +2600,11 @@ rewriter.setInsertionPointToStart(block); if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && resultElementTy.isInteger(8)) { - Value index = rewriter.create(loc, rewriter.getIndexType(), - inputValue); - Value offset = rewriter.create(loc, 128); - index = rewriter.create(loc, rewriter.getIndexType(), index, - offset); + Value index = rewriter.create( + loc, rewriter.getIndexType(), inputValue); + Value offset = rewriter.create(loc, 128); + index = rewriter.create(loc, rewriter.getIndexType(), + index, offset); Value extract = rewriter.create(loc, table, ValueRange{index}); rewriter.create(loc, extract); @@ -2597,35 +2613,35 @@ if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && resultElementTy.isInteger(32)) { - Value extend = rewriter.create( + Value extend = rewriter.create( loc, rewriter.getI32Type(), inputValue); - auto offset = - rewriter.create(loc, rewriter.getI32IntegerAttr(32768)); - auto seven = - rewriter.create(loc, rewriter.getI32IntegerAttr(7)); - auto one = - rewriter.create(loc, rewriter.getI32IntegerAttr(1)); - auto b1111111 = - rewriter.create(loc, rewriter.getI32IntegerAttr(127)); + auto offset = rewriter.create( + loc, rewriter.getI32IntegerAttr(32768)); + auto seven = rewriter.create( + loc, rewriter.getI32IntegerAttr(7)); + auto one = rewriter.create( + loc, rewriter.getI32IntegerAttr(1)); + auto b1111111 = rewriter.create( + loc, rewriter.getI32IntegerAttr(127)); // Compute the index and fractional part from the input value: // value = value + 32768 // index = value >> 7; // fraction = 0x01111111 & value - auto extendAdd = rewriter.create(loc, extend, offset); - Value index = 
- rewriter.create(loc, extendAdd, seven); - Value fraction = rewriter.create(loc, extendAdd, b1111111); + auto extendAdd = rewriter.create(loc, extend, offset); + Value index = rewriter.create(loc, extendAdd, seven); + Value fraction = + rewriter.create(loc, extendAdd, b1111111); // Extract the base and next values from the table. // base = (int32_t) table[index]; // next = (int32_t) table[index + 1]; - Value indexPlusOne = rewriter.create(loc, index, one); + Value indexPlusOne = rewriter.create(loc, index, one); - index = - rewriter.create(loc, rewriter.getIndexType(), index); - indexPlusOne = rewriter.create( + index = rewriter.create( + loc, rewriter.getIndexType(), index); + indexPlusOne = rewriter.create( loc, rewriter.getIndexType(), indexPlusOne); Value base = @@ -2633,15 +2649,18 @@ Value next = rewriter.create( loc, table, ValueRange{indexPlusOne}); - base = rewriter.create(loc, rewriter.getI32Type(), base); - next = rewriter.create(loc, rewriter.getI32Type(), next); + base = + rewriter.create(loc, rewriter.getI32Type(), base); + next = + rewriter.create(loc, rewriter.getI32Type(), next); // Use the fractional part to interpolate between the input values: // result = (base << 7) + (next - base) * fraction - Value baseScaled = rewriter.create(loc, base, seven); - Value diff = rewriter.create(loc, next, base); - Value diffScaled = rewriter.create(loc, diff, fraction); - Value result = rewriter.create(loc, baseScaled, diffScaled); + Value baseScaled = rewriter.create(loc, base, seven); + Value diff = rewriter.create(loc, next, base); + Value diffScaled = rewriter.create(loc, diff, fraction); + Value result = + rewriter.create(loc, baseScaled, diffScaled); rewriter.create(loc, result); @@ -2694,7 +2713,7 @@ pad.resize(pad.size() + 2, 0); Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = rewriter.create(loc, initialAttr); SmallVector kernel, stride; 
getValuesFromIntArrayAttribute(op.kernel(), kernel); @@ -2749,7 +2768,7 @@ Attribute initialAttr = rewriter.getZeroAttr(accETy); Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = rewriter.create(loc, initialAttr); + Value initialValue = rewriter.create(loc, initialAttr); SmallVector kernel, stride; getValuesFromIntArrayAttribute(op.kernel(), kernel); @@ -2791,18 +2810,18 @@ ArrayRef({affineMap, affineMap}), getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { - auto zero = rewriter.create(loc, 0); - auto one = rewriter.create(loc, 1); - auto iH = rewriter.create( + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto iH = rewriter.create( loc, poolingOpTy.getDimSize(1) - 1); - auto iW = rewriter.create( + auto iW = rewriter.create( loc, poolingOpTy.getDimSize(2) - 1); // Compute the indices from either end. auto y0 = rewriter.create(loc, 1); auto x0 = rewriter.create(loc, 2); - auto y1 = rewriter.create(loc, iH, y0); - auto x1 = rewriter.create(loc, iW, x0); + auto y1 = rewriter.create(loc, iH, y0); + auto x1 = rewriter.create(loc, iW, x0); // Determines what the portion of valid input is covered by the // kernel. @@ -2810,34 +2829,34 @@ if (pad == 0) return v; - auto padVal = rewriter.create(loc, pad); - Value dx = rewriter.create(loc, x, padVal); + auto padVal = rewriter.create(loc, pad); + Value dx = rewriter.create(loc, x, padVal); - Value cmp = rewriter.create(loc, CmpIPredicate::slt, - dx, zero); + Value cmp = rewriter.create( + loc, arith::CmpIPredicate::slt, dx, zero); Value offset = rewriter.create(loc, cmp, dx, zero); - return rewriter.create(loc, v, offset)->getResult(0); + return rewriter.create(loc, v, offset)->getResult(0); }; // Compute the vertical component of coverage. 
- auto kH0 = rewriter.create(loc, kernel[0]); + auto kH0 = rewriter.create(loc, kernel[0]); auto kH1 = padFn(kH0, y0, pad[2]); auto kH2 = padFn(kH1, y1, pad[3]); - auto kHCmp = - rewriter.create(loc, CmpIPredicate::slt, kH2, one); + auto kHCmp = rewriter.create( + loc, arith::CmpIPredicate::slt, kH2, one); auto kH3 = rewriter.create(loc, kHCmp, one, kH2); // compute the horizontal component of coverage. - auto kW0 = rewriter.create(loc, kernel[1]); + auto kW0 = rewriter.create(loc, kernel[1]); auto kW1 = padFn(kW0, x0, pad[4]); auto kW2 = padFn(kW1, x1, pad[5]); - auto kWCmp = - rewriter.create(loc, CmpIPredicate::slt, kW2, one); + auto kWCmp = rewriter.create( + loc, arith::CmpIPredicate::slt, kW2, one); auto kW3 = rewriter.create(loc, kWCmp, one, kW2); // Compute the total number of elements and normalize. - Value count = rewriter.create(loc, kH3, kW3); - auto countI = rewriter.create( + Value count = rewriter.create(loc, kH3, kW3); + auto countI = rewriter.create( loc, rewriter.getI32Type(), count); // Divide by the number of summed values. For floats this is just @@ -2846,20 +2865,21 @@ Value poolVal = args[0]; if (accETy.isa()) { auto countF = - rewriter.create(loc, inElementTy, countI); - poolVal = - rewriter.create(loc, poolVal, countF)->getResult(0); + rewriter.create(loc, inElementTy, countI); + poolVal = rewriter.create(loc, poolVal, countF) + ->getResult(0); } else { // If we have quantization information we need to apply an offset // for the input zp value. 
if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); - auto inputZp = rewriter.create( + auto inputZp = rewriter.create( loc, quantizationInfo.input_zp()); Value offset = - rewriter.create(loc, accETy, countI, inputZp); - poolVal = rewriter.create(loc, accETy, poolVal, offset); + rewriter.create(loc, accETy, countI, inputZp); + poolVal = + rewriter.create(loc, accETy, poolVal, offset); } // Compute the multiplier and shift values for the quantization @@ -2869,14 +2889,14 @@ int64_t numerator = ((1 << 30) + 1); int64_t shift = 30; - Value numeratorVal = rewriter.create( + Value numeratorVal = rewriter.create( loc, rewriter.getI32IntegerAttr(numerator)); Value multiplierVal = rewriter - .create(loc, rewriter.getI32Type(), + .create(loc, rewriter.getI32Type(), numeratorVal, countI) .getResult(); - Value shiftVal = rewriter.create( + Value shiftVal = rewriter.create( loc, rewriter.getI8IntegerAttr(shift)); auto scaled = @@ -2890,28 +2910,26 @@ // zeropoint. if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); - auto outputZp = rewriter.create( + auto outputZp = rewriter.create( loc, quantizationInfo.output_zp()); - scaled = - rewriter.create(loc, scaled, outputZp).getResult(); + scaled = rewriter.create(loc, scaled, outputZp) + .getResult(); } // Apply Clip. 
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth(); - auto min = rewriter.create( - loc, rewriter.getIntegerAttr( - accETy, - APInt::getSignedMinValue(outBitwidth).getSExtValue())); - auto max = rewriter.create( - loc, rewriter.getIntegerAttr( - accETy, - APInt::getSignedMaxValue(outBitwidth).getSExtValue())); - auto clamp = clampHelper( - loc, scaled, min, max, CmpIPredicate::slt, rewriter); + auto min = rewriter.create( + loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(), + accETy); + auto max = rewriter.create( + loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(), + accETy); + auto clamp = clampHelper( + loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter); // Convert type. - poolVal = rewriter.create(loc, resultETy, clamp); + poolVal = rewriter.create(loc, resultETy, clamp); } // Cast to output type. diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -12,6 +12,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" @@ -33,9 +34,9 @@ : public TosaToLinalgOnTensorsBase { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnFunction() override { diff --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRStandard MLIRPass diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp 
b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToStandard/TosaToStandard.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -28,7 +29,7 @@ LogicalResult matchAndRewrite(tosa::ConstOp op, PatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp<::ConstantOp>(op, op.value()); + rewriter.replaceOpWithNewOp(op, op.value()); return success(); } }; @@ -67,12 +68,12 @@ bool doubleRound = op.double_round(); Type inType = op.value().getType(); - Value one8 = rewriter.create( + Value one8 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1)); - Value one64 = rewriter.create( + Value one64 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); - Value shiftSubOne8 = rewriter.create(loc, shift8, one8); + Value shiftSubOne8 = rewriter.create(loc, shift8, one8); // The rounding value semantics below equate to the following code: // int64_t round = 1 << (shift - 1); @@ -83,45 +84,45 @@ // // Note that minimal bitwidth operators are used throughout the block. 
- Value round64 = rewriter.create( + Value round64 = rewriter.create( loc, one64, - rewriter.create(loc, rewriter.getI64Type(), - shiftSubOne8)); + rewriter.create(loc, rewriter.getI64Type(), + shiftSubOne8)); // Double rounding is performing a round operation before the shift if (doubleRound) { - Value one32 = rewriter.create( + Value one32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); - Value shift32 = rewriter.create( - loc, rewriter.getI32Type(), shift8); - Value thirty32 = rewriter.create( + Value shift32 = + rewriter.create(loc, rewriter.getI32Type(), shift8); + Value thirty32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30)); Value shiftThirty32 = - rewriter.create(loc, one32, thirty32); - Value shiftThirty64 = rewriter.create( + rewriter.create(loc, one32, thirty32); + Value shiftThirty64 = rewriter.create( loc, rewriter.getI64Type(), shiftThirty32); // Round value needs to with be added or subtracted depending on the sign // of the input value. Value roundAdd64 = - rewriter.create(loc, round64, shiftThirty64); + rewriter.create(loc, round64, shiftThirty64); Value roundSub64 = - rewriter.create(loc, round64, shiftThirty64); + rewriter.create(loc, round64, shiftThirty64); Value zero32 = - rewriter.create(loc, rewriter.getZeroAttr(inType)); - Value valueGreaterThanZero = rewriter.create( - loc, CmpIPredicate::sge, value32, zero32); + rewriter.create(loc, rewriter.getZeroAttr(inType)); + Value valueGreaterThanZero = rewriter.create( + loc, arith::CmpIPredicate::sge, value32, zero32); Value doubleRound64 = rewriter.create( loc, valueGreaterThanZero, roundAdd64, roundSub64); // We only perform double rounding if the shift value is greater than 32. 
- Value thirtyTwo32 = rewriter.create( + Value thirtyTwo32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32)); - Value shiftGreaterThanThirtyTwo = rewriter.create( - loc, CmpIPredicate::sge, shift32, thirtyTwo32); + Value shiftGreaterThanThirtyTwo = rewriter.create( + loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); round64 = rewriter.create(loc, shiftGreaterThanThirtyTwo, doubleRound64, round64); } @@ -133,20 +134,19 @@ // Note that multiply and shift need to be perform in i64 to preserve bits. Value value64 = - rewriter.create(loc, rewriter.getI64Type(), value32); - Value multiplier64 = rewriter.create( + rewriter.create(loc, rewriter.getI64Type(), value32); + Value multiplier64 = rewriter.create( loc, rewriter.getI64Type(), multiplier32); Value shift64 = - rewriter.create(loc, rewriter.getI64Type(), shift8); + rewriter.create(loc, rewriter.getI64Type(), shift8); // Multiply as a pair of i64 values to guarantee the end value fits. - Value result64 = rewriter.create(loc, value64, multiplier64); - result64 = rewriter.create(loc, result64, round64); - result64 = - rewriter.create(loc, result64, shift64); + Value result64 = rewriter.create(loc, value64, multiplier64); + result64 = rewriter.create(loc, result64, round64); + result64 = rewriter.create(loc, result64, shift64); - Value result32 = rewriter.create( - loc, rewriter.getI32Type(), result64); + Value result32 = + rewriter.create(loc, rewriter.getI32Type(), result64); rewriter.replaceOp(op, result32); return success(); diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp @@ -12,6 +12,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToStandard/TosaToStandard.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include 
"mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -34,6 +35,7 @@ target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -8,6 +8,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRGPUOps MLIRLLVMIR MLIRMemRef diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -16,6 +16,7 @@ #include "../PassDetail.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -116,7 +117,7 @@ /// Return true if the constant is a splat to a 2D vector so that it can be /// converted to a MMA constant matrix op. -static bool constantSupportsMMAMatrixType(ConstantOp constantOp) { +static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { auto vecType = constantOp.getType().dyn_cast(); if (!vecType || vecType.getRank() != 2) return false; @@ -138,7 +139,7 @@ return transferWriteSupportsMMAMatrixType(transferWrite); if (auto contract = dyn_cast(op)) return contractSupportsMMAMatrixType(contract); - if (auto constant = dyn_cast(op)) + if (auto constant = dyn_cast(op)) return constantSupportsMMAMatrixType(constant); if (auto broadcast = dyn_cast(op)) return broadcastSupportsMMAMatrixType(broadcast); @@ -324,13 +325,13 @@ } /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 
-static void convertConstantOp(ConstantOp op, +static void convertConstantOp(arith::ConstantOp op, llvm::DenseMap &valueMapping) { assert(constantSupportsMMAMatrixType(op)); OpBuilder b(op); - Attribute splat = op.getValue().cast().getSplatValue(); + Attribute splat = op.value().cast().getSplatValue(); auto scalarConstant = - b.create(op.getLoc(), splat.getType(), splat); + b.create(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); auto vecType = op.getType().cast(); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( @@ -439,7 +440,7 @@ convertTransferWriteOp(transferWrite, valueMapping); } else if (auto contractOp = dyn_cast(op)) { convertContractOp(contractOp, valueMapping); - } else if (auto constantOp = dyn_cast(op)) { + } else if (auto constantOp = dyn_cast(op)) { convertConstantOp(constantOp, valueMapping); } else if (auto broadcastOp = dyn_cast(op)) { convertBroadcastOp(broadcastOp, valueMapping); diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -13,6 +13,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRArmNeon MLIRArmSVE MLIRArmSVETransforms diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -59,7 +60,7 @@ return rewriter.create(loc, from, into, offset); return rewriter.create( loc, vectorType, from, into, - rewriter.create(loc, offset)); 
+ rewriter.create(loc, offset)); } // Helper that picks the proper sequence for extracting. @@ -86,7 +87,7 @@ return rewriter.create(loc, vector, offset); return rewriter.create( loc, vectorType.getElementType(), vector, - rewriter.create(loc, offset)); + rewriter.create(loc, offset)); } // Helper that returns a subset of `arrayAttr` as a vector of int64_t. @@ -795,8 +796,8 @@ auto loc = op.getLoc(); auto elemType = vType.getElementType(); - Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); + Value zero = rewriter.create( + loc, elemType, rewriter.getZeroAttr(elemType)); Value desc = rewriter.create(loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { Value extrLHS = rewriter.create(loc, op.lhs(), i); @@ -1144,11 +1145,11 @@ if (rank == 0) { switch (conversion) { case PrintConversion::ZeroExt64: - value = rewriter.create( + value = rewriter.create( loc, value, IntegerType::get(rewriter.getContext(), 64)); break; case PrintConversion::SignExt64: - value = rewriter.create( + value = rewriter.create( loc, value, IntegerType::get(rewriter.getContext(), 64)); break; case PrintConversion::None: @@ -1231,8 +1232,8 @@ } // Extract/insert on a lower ranked extract strided slice op. 
- Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); + Value zero = rewriter.create( + loc, elemType, rewriter.getZeroAttr(elemType)); Value res = rewriter.create(loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/AMX/Transforms.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" @@ -42,6 +43,7 @@ // Override explicitly to allow conditional dialect dependence. void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + registry.insert(); registry.insert(); if (enableArmNeon) registry.insert(); @@ -81,6 +83,7 @@ // Architecture specific augmentations. 
LLVMConversionTarget target(getContext()); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); diff --git a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt @@ -8,6 +8,7 @@ Core LINK_LIBS PUBLIC + MLIRArithmetic MLIRLLVMIR MLIRMemRef MLIRTransforms diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -17,6 +17,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -123,8 +124,8 @@ return Value(); Location loc = xferOp.getLoc(); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); return b.create(loc, xferOp.mask(), ivI32); } @@ -171,13 +172,14 @@ bindDims(xferOp.getContext(), d0, d1); Value base = xferOp.indices()[dim.getValue()]; Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); - cond = lb.create(CmpIPredicate::sgt, memrefDim, memrefIdx); + cond = lb.create(arith::CmpIPredicate::sgt, memrefDim, + memrefIdx); } // Condition check 2: Masked in? if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { if (cond) - cond = lb.create(cond, maskCond); + cond = lb.create(cond, maskCond); else cond = maskCond; } @@ -704,10 +706,10 @@ } // Loop bounds and step. 
- auto lb = locB.create(0); - auto ub = locB.create( + auto lb = locB.create(0); + auto ub = locB.create( castedDataType.getDimSize(castedDataType.getRank() - 1)); - auto step = locB.create(1); + auto step = locB.create(1); // TransferWriteOps that operate on tensors return the modified tensor and // require a loop state. auto loopState = Strategy::initialLoopState(xferOp); @@ -897,7 +899,7 @@ // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = rewriter.create(loc, i); vec = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), @@ -1023,7 +1025,7 @@ // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create(loc, i); + Value iv = rewriter.create(loc, i); auto updatedSource = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), @@ -1114,8 +1116,8 @@ ValueRange loopState) { SmallVector indices; auto dim = get1dMemrefIndices(b, xferOp, iv, indices); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); auto vec = loopState[0]; // In case of out-of-bounds access, leave `vec` as is (was initialized with @@ -1147,8 +1149,8 @@ ValueRange /*loopState*/) { SmallVector indices; auto dim = get1dMemrefIndices(b, xferOp, iv, indices); - Value ivI32 = - b.create(loc, IntegerType::get(b.getContext(), 32), iv); + Value ivI32 = b.create( + loc, IntegerType::get(b.getContext(), 32), iv); // Nothing to do in case of out-of-bounds access. generateInBoundsCheck( @@ -1224,9 +1226,10 @@ // Loop bounds, step, state... 
Location loc = xferOp.getLoc(); auto vecType = xferOp.getVectorType(); - auto lb = rewriter.create(loc, 0); - auto ub = rewriter.create(loc, vecType.getDimSize(0)); - auto step = rewriter.create(loc, 1); + auto lb = rewriter.create(loc, 0); + auto ub = + rewriter.create(loc, vecType.getDimSize(0)); + auto step = rewriter.create(loc, 1); auto loopState = Strategy1d::initialLoopState(rewriter, xferOp); // Generate for loop. diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -221,7 +222,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + return builder.create(loc, type, value); } /// A utility function to check if a value is defined at the top level of an @@ -1879,12 +1880,11 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { - auto lbConst = lb.getDefiningOp(); - auto ubConst = ub.getDefiningOp(); + auto lbConst = lb.getDefiningOp(); + auto ubConst = ub.getDefiningOp(); if (lbConst && ubConst) - return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(), - ubConst.getValue(), step, - bodyBuilderFn); + return buildAffineLoopFromConstants(builder, loc, lbConst.value(), + ubConst.value(), step, bodyBuilderFn); return builder.create(loc, lb, builder.getDimIdentityMap(), ub, builder.getDimIdentityMap(), step, /*iterArgs=*/llvm::None, bodyBuilderFn); diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt 
b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt @@ -11,6 +11,7 @@ MLIRAffineOpsIncGen LINK_LIBS PUBLIC + MLIRArithmetic MLIRIR MLIRLoopLikeInterface MLIRMemRef diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -23,6 +23,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -199,7 +200,7 @@ void AffineDataCopyGeneration::runOnFunction() { FuncOp f = getFunction(); OpBuilder topBuilder(f.getBody()); - zeroIndex = topBuilder.create(f.getLoc(), 0); + zeroIndex = topBuilder.create(f.getLoc(), 0); // Nests that are copy-in's or copy-out's; the root AffineForOps of those // nests are stored herein. diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -18,6 +18,7 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -81,7 +82,7 @@ } else if (isa(op)) { // TODO: Support DMA ops. return false; - } else if (!isa(op)) { + } else if (!isa(op)) { // Register op in the set of ops that have users. 
opsWithUsers.insert(&op); if (isa(op)) { diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -21,6 +21,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRAffineUtils + MLIRArithmetic MLIRIR MLIRMemRef MLIRPass diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -343,8 +344,8 @@ /// %A = alloc (%M, %N) : memref /// %B = alloc (%M, %N) : memref /// %C = alloc (%M, %N) : memref -/// %f1 = constant 1.0 : f32 -/// %f2 = constant 2.0 : f32 +/// %f1 = arith.const 1.0 : f32 +/// %f2 = arith.const 2.0 : f32 /// affine.for %i0 = 0 to %M { /// affine.for %i1 = 0 to %N { /// // non-scoped %f1 @@ -361,18 +362,18 @@ /// affine.for %i5 = 0 to %N { /// %a5 = affine.load %A[%i4, %i5] : memref /// %b5 = affine.load %B[%i4, %i5] : memref -/// %s5 = addf %a5, %b5 : f32 +/// %s5 = arith.addf %a5, %b5 : f32 /// // non-scoped %f1 -/// %s6 = addf %s5, %f1 : f32 +/// %s6 = arith.addf %s5, %f1 : f32 /// // non-scoped %f2 -/// %s7 = addf %s5, %f2 : f32 +/// %s7 = arith.addf %s5, %f2 : f32 /// // diamond dependency. 
-/// %s8 = addf %s7, %s6 : f32 +/// %s8 = arith.addf %s7, %s6 : f32 /// affine.store %s8, %C[%i4, %i5] : memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.const 7 : index +/// %c42 = arith.const 42 : index /// %res = load %C[%c7, %c42] : memref /// return %res : f32 /// } @@ -389,11 +390,11 @@ /// %0 = alloc(%arg0, %arg1) : memref /// %1 = alloc(%arg0, %arg1) : memref /// %2 = alloc(%arg0, %arg1) : memref -/// %cst = constant 1.0 : f32 -/// %cst_0 = constant 2.0 : f32 +/// %cst = arith.const 1.0 : f32 +/// %cst_0 = arith.const 2.0 : f32 /// affine.for %i0 = 0 to %arg0 { /// affine.for %i1 = 0 to %arg1 step 256 { -/// %cst_1 = constant dense, 1.0> : +/// %cst_1 = arith.const dense, 1.0> : /// vector<256xf32> /// vector.transfer_write %cst_1, %0[%i0, %i1] : /// vector<256xf32>, memref @@ -401,7 +402,7 @@ /// } /// affine.for %i2 = 0 to %arg0 { /// affine.for %i3 = 0 to %arg1 step 256 { -/// %cst_2 = constant dense, 2.0> : +/// %cst_2 = arith.const dense, 2.0> : /// vector<256xf32> /// vector.transfer_write %cst_2, %1[%i2, %i3] : /// vector<256xf32>, memref @@ -413,20 +414,20 @@ /// memref, vector<256xf32> /// %4 = vector.transfer_read %1[%i4, %i5] : /// memref, vector<256xf32> -/// %5 = addf %3, %4 : vector<256xf32> -/// %cst_3 = constant dense, 1.0> : +/// %5 = arith.addf %3, %4 : vector<256xf32> +/// %cst_3 = arith.const dense, 1.0> : /// vector<256xf32> -/// %6 = addf %5, %cst_3 : vector<256xf32> -/// %cst_4 = constant dense, 2.0> : +/// %6 = arith.addf %5, %cst_3 : vector<256xf32> +/// %cst_4 = arith.const dense, 2.0> : /// vector<256xf32> -/// %7 = addf %5, %cst_4 : vector<256xf32> -/// %8 = addf %7, %6 : vector<256xf32> +/// %7 = arith.addf %5, %cst_4 : vector<256xf32> +/// %8 = arith.addf %7, %6 : vector<256xf32> /// vector.transfer_write %8, %2[%i4, %i5] : /// vector<256xf32>, memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.const 7 : index +/// %c42 = arith.const 
42 : index /// %9 = load %2[%c7, %c42] : memref /// return %9 : f32 /// } @@ -443,11 +444,11 @@ /// %0 = alloc(%arg0, %arg1) : memref /// %1 = alloc(%arg0, %arg1) : memref /// %2 = alloc(%arg0, %arg1) : memref -/// %cst = constant 1.0 : f32 -/// %cst_0 = constant 2.0 : f32 +/// %cst = arith.const 1.0 : f32 +/// %cst_0 = arith.const 2.0 : f32 /// affine.for %i0 = 0 to %arg0 step 32 { /// affine.for %i1 = 0 to %arg1 step 256 { -/// %cst_1 = constant dense, 1.0> : +/// %cst_1 = arith.const dense, 1.0> : /// vector<32x256xf32> /// vector.transfer_write %cst_1, %0[%i0, %i1] : /// vector<32x256xf32>, memref @@ -455,7 +456,7 @@ /// } /// affine.for %i2 = 0 to %arg0 step 32 { /// affine.for %i3 = 0 to %arg1 step 256 { -/// %cst_2 = constant dense, 2.0> : +/// %cst_2 = arith.const dense, 2.0> : /// vector<32x256xf32> /// vector.transfer_write %cst_2, %1[%i2, %i3] : /// vector<32x256xf32>, memref @@ -467,20 +468,20 @@ /// memref vector<32x256xf32> /// %4 = vector.transfer_read %1[%i4, %i5] : /// memref, vector<32x256xf32> -/// %5 = addf %3, %4 : vector<32x256xf32> -/// %cst_3 = constant dense, 1.0> : +/// %5 = arith.addf %3, %4 : vector<32x256xf32> +/// %cst_3 = arith.const dense, 1.0> : /// vector<32x256xf32> -/// %6 = addf %5, %cst_3 : vector<32x256xf32> -/// %cst_4 = constant dense, 2.0> : +/// %6 = arith.addf %5, %cst_3 : vector<32x256xf32> +/// %cst_4 = arith.const dense, 2.0> : /// vector<32x256xf32> -/// %7 = addf %5, %cst_4 : vector<32x256xf32> -/// %8 = addf %7, %6 : vector<32x256xf32> +/// %7 = arith.addf %5, %cst_4 : vector<32x256xf32> +/// %8 = arith.addf %7, %6 : vector<32x256xf32> /// vector.transfer_write %8, %2[%i4, %i5] : /// vector<32x256xf32>, memref /// } /// } -/// %c7 = constant 7 : index -/// %c42 = constant 42 : index +/// %c7 = arith.const 7 : index +/// %c42 = arith.const 42 : index /// %9 = load %2[%c7, %c42] : memref /// return %9 : f32 /// } @@ -510,11 +511,11 @@ /// Consider the following example: /// ```mlir /// func @vecred(%in: 
memref<512xf32>) -> f32 { -/// %cst = constant 0.000000e+00 : f32 +/// %cst = arith.const 0.000000e+00 : f32 /// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) { /// %ld = affine.load %in[%i] : memref<512xf32> /// %cos = math.cos %ld : f32 -/// %add = addf %part_sum, %cos : f32 +/// %add = arith.addf %part_sum, %cos : f32 /// affine.yield %add : f32 /// } /// return %sum : f32 @@ -530,18 +531,18 @@ /// ```mlir /// #map = affine_map<(d0) -> (-d0 + 500)> /// func @vecred(%arg0: memref<512xf32>) -> f32 { -/// %cst = constant 0.000000e+00 : f32 -/// %cst_0 = constant dense<0.000000e+00> : vector<128xf32> +/// %cst = arith.const 0.000000e+00 : f32 +/// %cst_0 = arith.const dense<0.000000e+00> : vector<128xf32> /// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0) /// -> (vector<128xf32>) { /// // %2 is the number of iterations left in the original loop. /// %2 = affine.apply #map(%arg1) /// %3 = vector.create_mask %2 : vector<128xi1> -/// %cst_1 = constant 0.000000e+00 : f32 +/// %cst_1 = arith.const 0.000000e+00 : f32 /// %4 = vector.transfer_read %arg0[%arg1], %cst_1 : /// memref<512xf32>, vector<128xf32> /// %5 = math.cos %4 : vector<128xf32> -/// %6 = addf %arg2, %5 : vector<128xf32> +/// %6 = arith.addf %arg2, %5 : vector<128xf32> /// // We filter out the effect of last 12 elements using the mask. /// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32> /// affine.yield %7 : vector<128xf32> @@ -673,8 +674,8 @@ /// the vectorized operations. /// /// Example: - /// * 'replaced': %0 = addf %1, %2 : f32 - /// * 'replacement': %0 = addf %1, %2 : vector<128xf32> + /// * 'replaced': %0 = arith.addf %1, %2 : f32 + /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> void registerOpVectorReplacement(Operation *replaced, Operation *replacement); /// Registers the vector replacement of a scalar value. The replacement @@ -771,8 +772,8 @@ /// the vectorized operations. 
/// /// Example: -/// * 'replaced': %0 = addf %1, %2 : f32 -/// * 'replacement': %0 = addf %1, %2 : vector<128xf32> +/// * 'replaced': %0 = arith.addf %1, %2 : f32 +/// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> void VectorizationState::registerOpVectorReplacement(Operation *replaced, Operation *replacement) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n"); @@ -940,14 +941,14 @@ /// Tries to transform a scalar constant into a vector constant. Returns the /// vector constant if the scalar type is valid vector element type. Returns /// nullptr, otherwise. -static ConstantOp vectorizeConstant(ConstantOp constOp, - VectorizationState &state) { +static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp, + VectorizationState &state) { Type scalarTy = constOp.getType(); if (!VectorType::isValidElementType(scalarTy)) return nullptr; auto vecTy = getVectorType(scalarTy, state.strategy); - auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue()); + auto vecAttr = DenseElementsAttr::get(vecTy, constOp.value()); OpBuilder::InsertionGuard guard(state.builder); Operation *parentOp = state.builder.getInsertionBlock()->getParentOp(); @@ -958,7 +959,8 @@ isa(parentOp) && "Expected a vectorized for op"); auto vecForOp = cast(parentOp); state.builder.setInsertionPointToStart(vecForOp.getBody()); - auto newConstOp = state.builder.create(constOp.getLoc(), vecAttr); + auto newConstOp = + state.builder.create(constOp.getLoc(), vecAttr); // Register vector replacement for future uses in the scope. state.registerOpVectorReplacement(constOp, newConstOp); @@ -968,9 +970,9 @@ /// Creates a constant vector filled with the neutral elements of the given /// reduction. The scalar type of vector elements will be taken from /// `oldOperand`. 
-static ConstantOp createInitialVector(AtomicRMWKind reductionKind, - Value oldOperand, - VectorizationState &state) { +static arith::ConstantOp createInitialVector(AtomicRMWKind reductionKind, + Value oldOperand, + VectorizationState &state) { Type scalarTy = oldOperand.getType(); if (!VectorType::isValidElementType(scalarTy)) return nullptr; @@ -980,7 +982,7 @@ auto vecTy = getVectorType(scalarTy, state.strategy); auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); auto newConstOp = - state.builder.create(oldOperand.getLoc(), vecAttr); + state.builder.create(oldOperand.getLoc(), vecAttr); return newConstOp; } @@ -1120,8 +1122,8 @@ "Vector op not found in replacement map"); // Vectorize constant. - if (auto constOp = operand.getDefiningOp()) { - ConstantOp vecConstant = vectorizeConstant(constOp, state); + if (auto constOp = operand.getDefiningOp()) { + auto vecConstant = vectorizeConstant(constOp, state); LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant); return vecConstant.getResult(); } @@ -1242,7 +1244,7 @@ return false; Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy, state.builder, value.getLoc()); - if (auto constOp = dyn_cast_or_null(value.getDefiningOp())) + if (auto constOp = dyn_cast_or_null(value.getDefiningOp())) return constOp.value() == valueAttr; return false; } @@ -1417,7 +1419,7 @@ // being added to the accumulator by inserting `select` operations, for // example: // - // %res = addf %acc, %val : vector<128xf32> + // %res = arith.addf %acc, %val : vector<128xf32> // %res_masked = select %mask, %res, %acc : vector<128xi1>, vector<128xf32> // affine.yield %res_masked : vector<128xf32> // @@ -1464,7 +1466,7 @@ return vectorizeAffineForOp(forOp, state); if (auto yieldOp = dyn_cast(op)) return vectorizeAffineYieldOp(yieldOp, state); - if (auto constant = dyn_cast(op)) + if (auto constant = dyn_cast(op)) return vectorizeConstant(constant, state); // Other ops with regions are not supported. 
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/IR/Builders.h" #include "mlir/Transforms/InliningUtils.h" using namespace mlir; @@ -28,10 +29,22 @@ }; } // end anonymous namespace -void mlir::arith::ArithmeticDialect::initialize() { +namespace mlir { + +void arith::ArithmeticDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc" >(); addInterfaces(); } + +/// Materialize an integer or floating point constant. +Operation *arith::ArithmeticDialect::materializeConstant(OpBuilder &builder, + Attribute value, + Type type, + Location loc) { + return builder.create(loc, value, type); +} + +} // end namespace mlir diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -75,6 +75,92 @@ #include "ArithmeticCanonicalization.inc" } // end anonymous namespace +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +void arith::ConstantOp::getAsmResultNames( + function_ref setNameFn) { + auto type = getType(); + if (auto intCst = value().dyn_cast()) { + auto intType = type.dyn_cast(); + + // Sugar i1 constants with 'true' and 'false'. + if (intType && intType.getWidth() == 1) + return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); + + // Otherwise, build a compex name with the value and type. 
+ SmallString<32> specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << 'c' << intCst.getInt(); + if (intType) + specialName << '_' << type; + setNameFn(getResult(), specialName.str()); + } else { + setNameFn(getResult(), "cst"); + } +} + +bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { + // The value's type must be the same as the provided type. + if (value.getType() != type) + return false; + // Integers values must be signless. + if (type.isa() && !type.cast().isSignless()) + return false; + // Integer, float, and element attributes are buildable. + return value.isa(); +} + +OpFoldResult arith::ConstantOp::fold(ArrayRef operands) { + return value(); +} + +void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, + int64_t value, unsigned width) { + auto type = builder.getIntegerType(width); + arith::ConstantOp::build(builder, result, type, + builder.getIntegerAttr(type, value)); +} + +void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, + int64_t value, Type type) { + assert(type.isSignlessInteger() && + "ConstantIntOp can only have signless integer type values"); + arith::ConstantOp::build(builder, result, type, + builder.getIntegerAttr(type, value)); +} + +bool arith::ConstantIntOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isSignlessInteger(); + return false; +} + +void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, + const APFloat &value, FloatType type) { + arith::ConstantOp::build(builder, result, type, + builder.getFloatAttr(type, value)); +} + +bool arith::ConstantFloatOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isa(); + return false; +} + +void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, + int64_t value) { + arith::ConstantOp::build(builder, result, builder.getIndexType(), + 
builder.getIndexAttr(value)); +} + +bool arith::ConstantIndexOp::classof(Operation *op) { + if (auto constOp = dyn_cast_or_null(op)) + return constOp.getType().isIndex(); + return false; +} + //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// @@ -439,6 +525,51 @@ operands, [](APFloat a, APFloat b) { return a / b; }); } +//===----------------------------------------------------------------------===// +// Utility functions for verifying cast ops +//===----------------------------------------------------------------------===// + +namespace { +template +struct type_list {}; +} // end anonymous namespace + +/// Returns a non-null type only if the provided type is one of the allowed +/// types or one of the allowed shaped types of the allowed types. Returns the +/// element type if a valid shaped type is provided. +template +static Type getUnderlyingType(Type type, type_list, + type_list) { + if (type.isa() && !type.isa()) + return {}; + + auto underlyingType = getElementTypeOrSelf(type); + if (!underlyingType.isa()) + return {}; + + return underlyingType; +} + +/// Get allowed underlying types for vectors and tensors. +template +static Type getTypeIfLike(Type type) { + return getUnderlyingType(type, type_list(), + type_list()); +} + +/// Get allowed underlying types for vectors, tensors, and memrefs. 
+template +static Type getTypeIfLikeOrMemRef(Type type) { + return getUnderlyingType(type, + type_list(), + type_list()); +} + +static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { + return inputs.size() == 1 && outputs.size() == 1 && + succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); +} + //===----------------------------------------------------------------------===// // Verifiers for integer and floating point extension/truncation ops //===----------------------------------------------------------------------===// @@ -469,6 +600,21 @@ return success(); } +/// Validate a cast that changes the width of a type. +template