diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -24,8 +24,18 @@
-/// Create an integer or index constant.
+/// Create an integer or index constant of type `type` holding `value`. When
+/// `type` is a shaped type (e.g. a vector), a splat constant is created whose
+/// elements all equal `value`. NOTE: DenseElementsAttr requires a static
+/// shape, so dynamically shaped tensors are not supported by this helper.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
-  return rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(type, value));
+  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = llvm::dyn_cast<ShapedType>(type)) {
+    // Emit the splat directly as a dense attribute instead of broadcasting a
+    // scalar constant: the result keeps the exact operand type (vector or
+    // tensor) and no additional dialect has to be loaded.
+    return rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
+  return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -114,6 +114,36 @@
 
 // -----
 
+// Test floor divide with vector
+// CHECK-LABEL: func.func @floordivi_vec(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<4xi32>) -> vector<4xi32> {
+func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) {
+  %res = arith.floordivsi %arg0, %arg1 : vector<4xi32>
+  return %res : vector<4xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<1> : vector<4xi32>
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<-1> : vector<4xi32>
+// CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
+// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_2]], %[[VAL_4]] : vector<4xi1>, vector<4xi32>
+// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_6]], %[[VAL_0]] : vector<4xi32>
+// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_1]] : vector<4xi32>
+// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_4]], %[[VAL_8]] : vector<4xi32>
+// CHECK: %[[VAL_10:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32>
+// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
+// CHECK: %[[VAL_12:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
+// CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
+// CHECK: %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
+// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_11]], %[[VAL_14]] : vector<4xi1>
+// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : vector<4xi1>
+// CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_15]], %[[VAL_16]] : vector<4xi1>
+// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_10]] : vector<4xi1>, vector<4xi32>
+// CHECK: return %[[VAL_18]] : vector<4xi32>
+// CHECK: }
+}
+
+// -----
+
 // Test ceil divide with unsigned integer
 // CHECK-LABEL: func @ceildivui
 // CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {