diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -69,7 +69,9 @@ // Base class for arithmetic cast operations. class ArithmeticCastOp traits = []> : - CastOp { + CastOp])> { } // Base class for unary ops. Requires single operand and result. Individual @@ -104,6 +106,7 @@ Op, ElementwiseMappable])> { let results = (outs AnyType:$result); @@ -992,6 +995,7 @@ def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, ElementwiseMappable, + DeclareOpInterfaceMethods, TypesMatchWith< "result type has i1 element type and same shape as operands", "lhs", "result", "getI1SameShape($_self)">]> { @@ -1076,6 +1080,7 @@ def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, ElementwiseMappable, + DeclareOpInterfaceMethods, TypesMatchWith< "result type has i1 element type and same shape as operands", "lhs", "result", "getI1SameShape($_self)">]> { @@ -2548,7 +2553,7 @@ def SelectOp : Std_Op<"select", [NoSideEffect, AllTypesMatch<["true_value", "false_value", "result"]>, - ElementwiseMappable]> { + ElementwiseMappable, DeclareOpInterfaceMethods]> { let summary = "select operation"; let description = [{ The `select` operation chooses one value based on a binary condition @@ -2779,7 +2784,8 @@ //===----------------------------------------------------------------------===// def SignExtendIOp : Std_Op<"sexti", - [NoSideEffect, ElementwiseMappable]> { + [NoSideEffect, ElementwiseMappable, + DeclareOpInterfaceMethods]> { let summary = "integer sign extension operation"; let description = [{ The integer sign extension operation takes an integer input of @@ -3595,7 +3601,9 @@ // TruncateIOp //===----------------------------------------------------------------------===// -def TruncateIOp : Std_Op<"trunci", [NoSideEffect, ElementwiseMappable]> { +def TruncateIOp : Std_Op<"trunci", + [NoSideEffect, ElementwiseMappable, + DeclareOpInterfaceMethods,]> { let summary = "integer truncation operation"; let description = [{ The integer truncation operation takes an integer input of @@ -3862,7 +3870,9 @@ // ZeroExtendIOp //===----------------------------------------------------------------------===// -def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, ElementwiseMappable]> { +def ZeroExtendIOp : Std_Op<"zexti", + [NoSideEffect, ElementwiseMappable, + DeclareOpInterfaceMethods,]> { let summary = "integer zero extension operation"; let description = [{ The integer zero extension operation takes an integer input of diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -492,8 +492,9 @@ assert(resultType && "Expected op with vector result type"); auto resultShape = resultType.getShape(); // Verify that all operands have the same vector type as result. - assert(llvm::all_of(op->getOperandTypes(), - [=](Type type) { return type == resultType; })); + assert(llvm::all_of(op->getOperandTypes(), [=](Type type) { + return type.cast().getShape() == resultShape; + })); // Create trivial elementwise identity index map based on 'resultShape'. DenseMap indexMap; @@ -504,8 +505,9 @@ // Create VectorState each operand and single result. unsigned numVectors = op->getNumOperands() + op->getNumResults(); vectors.resize(numVectors); - for (unsigned i = 0; i < op->getNumOperands(); ++i) - vectors[i] = {resultType, indexMap, i, false}; + for (auto it : llvm::enumerate(op->getOperandTypes())) + vectors[it.index()] = {it.value().cast(), indexMap, + static_cast(it.index()), false}; vectors[numVectors - 1] = {resultType, indexMap, -1, false}; resultIndex = numVectors - 1; } diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -1,5 +1,4 @@ // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s -// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)> @@ -514,3 +513,38 @@ return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32> } + +// CHECK-LABEL: func @elementwise_unroll +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>) +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[CMP0:.*]] = cmpf "ult", %[[VT0]], %[[VT4]] : vector<2x2xf32> +// CHECK: %[[CMP1:.*]] = cmpf "ult", %[[VT1]], %[[VT5]] : vector<2x2xf32> +// CHECK: %[[CMP2:.*]] = cmpf "ult", %[[VT2]], %[[VT6]] : vector<2x2xf32> +// CHECK: %[[CMP3:.*]] = cmpf "ult", %[[VT3]], %[[VT7]] : vector<2x2xf32> +// CHECK: %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32> +// CHECK: %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32> +// CHECK: %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32> +// CHECK: %[[SEL3:.*]] = select %[[CMP3]], %[[VT3]], %[[VT7]] : vector<2x2xi1>, vector<2x2xf32> +// CHECK: vector.transfer_write %[[SEL0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK: vector.transfer_write %[[SEL1]], %[[ARG0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK: vector.transfer_write %[[SEL2]], %[[ARG0]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK: vector.transfer_write %[[SEL3]], %[[ARG0]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) { + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + %cond = cmpf "ult", %0, %1 : vector<4x4xf32> + %2 = select %cond, %0, %1 : vector<4x4xi1>, vector<4x4xf32> + vector.transfer_write %2, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> + return +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -37,8 +37,11 @@ private: // Return the target shape based on op type. static Optional> getShape(Operation *op) { - if (isa(op)) + if (isa(op)) return SmallVector(2, 2); + if (auto transferOp = dyn_cast(op)) { + return SmallVector(transferOp.getVectorType().getRank(), 2); + } if (isa(op)) return SmallVector(3, 2); return llvm::None;