diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -817,59 +817,6 @@ } }; -// TODO(ajcbik): remove this rule once LinAlg tests are cleaned up -class VectorOuterProductOpConversion : public ConvertToLLVMPattern { -public: - explicit VectorOuterProductOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(), - context, typeConverter) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto adaptor = vector::OuterProductOpOperandAdaptor(operands); - auto *ctx = op->getContext(); - auto vLHS = adaptor.lhs().getType().cast(); - auto vRHS = adaptor.rhs().getType().cast(); - auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); - auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); - auto llvmArrayOfVectType = typeConverter.convertType( - cast(op).getResult().getType()); - Value desc = rewriter.create(loc, llvmArrayOfVectType); - Value a = adaptor.lhs(), b = adaptor.rhs(); - Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); - SmallVector lhs, accs; - lhs.reserve(rankLHS); - accs.reserve(rankLHS); - for (unsigned d = 0, e = rankLHS; d < e; ++d) { - // shufflevector explicitly requires i32. - auto attr = rewriter.getI32IntegerAttr(d); - SmallVector bcastAttr(rankRHS, attr); - auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); - Value aD = nullptr, accD = nullptr; - // 1. Broadcast the element a[d] into vector aD. - aD = rewriter.create(loc, a, a, bcastArrayAttr); - // 2. If acc is present, extract 1-d vector acc[d] into accD. - if (acc) - accD = rewriter.create( - loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); - // 3. 
Compute aD outer b (plus accD, if relevant). - Value aOuterbD = - accD - ? rewriter.create(loc, vRHS, aD, b, accD).getResult() - : rewriter.create(loc, aD, b).getResult(); - // 4. Insert as value `d` in the descriptor. - desc = rewriter.create(loc, llvmArrayOfVectType, - desc, aOuterbD, - rewriter.getI64ArrayAttr(d)); - } - rewriter.replaceOp(op, desc); - return matchSuccess(); - } -}; - class VectorTypeCastOpConversion : public ConvertToLLVMPattern { public: explicit VectorTypeCastOpConversion(MLIRContext *context, @@ -1160,8 +1107,8 @@ VectorShuffleOpConversion, VectorExtractElementOpConversion, VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertElementOpConversion, VectorInsertOpConversion, - VectorOuterProductOpConversion, VectorTypeCastOpConversion, - VectorPrintOpConversion>(ctx, converter); + VectorTypeCastOpConversion, VectorPrintOpConversion>( + ctx, converter); } void mlir::populateVectorToLLVMMatrixConversionPatterns( diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s -// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS +// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefix=LLVM-LOOPS func @range(%arg0: index) { %c0 = constant 0 : index @@ -172,14 +172,22 @@ // CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64 // LLVM-LOOPS-LABEL: func @matmul_vec_impl( -// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// LLVM-LOOPS: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32, 1 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// LLVM-LOOPS: llvm.shufflevector {{.*}} [2 : i32, 2 : i32, 2 : i32, 2 : i32] : !llvm<"<4 x float>">, 
!llvm<"<4 x float>"> -// LLVM-LOOPS: llvm.shufflevector {{.*}} [3 : i32, 3 : i32, 3 : i32, 3 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]"> -// LLVM-LOOPS-NEXT: "llvm.intr.fma"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> -// LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]"> - +// LLVM-LOOPS-SAME: %[[A:.*0]]: memref>, +// LLVM-LOOPS-SAME: %[[B:.*1]]: memref>, +// LLVM-LOOPS-SAME: %[[C:.*2]]: memref>) +// LLVM-LOOPS: %[[C0:.*]] = constant 0 : index +// LLVM-LOOPS: %[[C1:.*]] = constant 1 : index +// LLVM-LOOPS: %[[T0:.*]] = dim %[[A]], 0 : memref> +// LLVM-LOOPS: %[[T1:.*]] = dim %[[A]], 1 : memref> +// LLVM-LOOPS: %[[T2:.*]] = dim %[[B]], 1 : memref> +// LLVM-LOOPS: loop.for %[[I:.*]] = %[[C0]] to %[[T0]] step %[[C1]] { +// LLVM-LOOPS: loop.for %[[J:.*]] = %[[C0]] to %[[T2]] step %[[C1]] { +// LLVM-LOOPS: loop.for %[[K:.*]] = %[[C0]] to %[[T1]] step %[[C1]] { +// LLVM-LOOPS: %[[T3:.*]] = load %[[A]][%[[I]], %[[K]]] : memref> +// LLVM-LOOPS: %[[T4:.*]] = load %[[B]][%[[K]], %[[J]]] : memref> +// LLVM-LOOPS: %[[T5:.*]] = load %[[C]][%[[I]], %[[J]]] : memref> +// LLVM-LOOPS: %[[T6:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T5]] : vector<4xf32>, vector<4xf32> +// LLVM-LOOPS: store %[[T6]], %[[C]][%[[I]], %[[J]]] : memref> #indexed_matmul_trait = { args_in = 2,