diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -656,9 +656,6 @@ static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); - static_assert(std::is_base_of, - SourceOp>::value, - "expected same operands and result type"); return LLVM::detail::vectorOneToOneRewrite( op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), rewriter); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1472,10 +1472,10 @@ // 1-D LLVM vectors. struct NDVectorTypeInfo { // LLVM array struct which encodes n-D vectors. - Type llvmArrayTy; + Type llvmNDVectorTy; // LLVM vector type which encodes the inner 1-D vector type. - Type llvmVectorTy; - // Multiplicity of llvmArrayTy to llvmVectorTy. + Type llvm1DVectorTy; + // Element counts of each nested LLVM array level of llvmNDVectorTy, outermost first.
SmallVector arraySizes; }; } // namespace @@ -1488,13 +1488,13 @@ LLVMTypeConverter &converter) { assert(vectorType.getRank() > 1 && "expected >1D vector type"); NDVectorTypeInfo info; - info.llvmArrayTy = converter.convertType(vectorType); - if (!info.llvmArrayTy || !LLVM::isCompatibleType(info.llvmArrayTy)) { - info.llvmArrayTy = nullptr; + info.llvmNDVectorTy = converter.convertType(vectorType); + if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) { + info.llvmNDVectorTy = nullptr; return info; } info.arraySizes.reserve(vectorType.getRank() - 1); - auto llvmTy = info.llvmArrayTy; + auto llvmTy = info.llvmNDVectorTy; while (llvmTy.isa()) { info.arraySizes.push_back( llvmTy.cast().getNumElements()); @@ -1502,7 +1502,7 @@ } if (!LLVM::isCompatibleVectorType(llvmTy)) return info; - info.llvmVectorTy = llvmTy; + info.llvm1DVectorTy = llvmTy; return info; } @@ -1591,27 +1591,33 @@ Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { - auto vectorType = op->getResult(0).getType().dyn_cast(); - if (!vectorType) - return failure(); - auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter); - auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; - auto llvmArrayTy = operands[0].getType(); - if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) - return failure(); - + auto operandNDVectorType = op->getOperand(0).getType().dyn_cast(); + auto resultNDVectorType = op->getResult(0).getType().dyn_cast(); + assert(operandNDVectorType && resultNDVectorType && "expected vector types"); + + auto resultTypeInfo = + extractNDVectorTypeInfo(resultNDVectorType, typeConverter); + auto operandTypeInfo = + extractNDVectorTypeInfo(operandNDVectorType, typeConverter); + auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; + auto operand1DVectorTy = operandTypeInfo.llvm1DVectorTy; + auto resultNDVectorTy = resultTypeInfo.llvmNDVectorTy; + // extractNDVectorTypeInfo leaves these types null when a vector type cannot + // be lowered to nested LLVM arrays of 1-D vectors; bail out before emitting IR. + if (!result1DVectorTy || !operand1DVectorTy || !resultNDVectorTy) + return failure(); auto loc = op->getLoc();
- Value desc = rewriter.create(loc, llvmArrayTy); - nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { + Value desc = rewriter.create(loc, resultNDVectorTy); + nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors SmallVector extractedOperands; for (auto operand : operands) extractedOperands.push_back(rewriter.create( - loc, llvmVectorTy, operand, position)); - Value newVal = createOperand(llvmVectorTy, extractedOperands); - desc = rewriter.create(loc, llvmArrayTy, desc, newVal, - position); + loc, operand1DVectorTy, operand, position)); + Value newVal = createOperand(result1DVectorTy, extractedOperands); + desc = rewriter.create(loc, resultNDVectorTy, desc, + newVal, position); }); rewriter.replaceOp(op, desc); return success(); @@ -1627,14 +1633,14 @@ [](Type t) { return isCompatibleType(t); })) return failure(); - auto llvmArrayTy = operands[0].getType(); - if (!llvmArrayTy.isa()) + auto llvmNDVectorTy = operands[0].getType(); + if (!llvmNDVectorTy.isa()) return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); - auto callback = [op, targetOp, &rewriter](Type llvmVectorTy, + auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, ValueRange operands) { OperationState state(op->getLoc(), targetOp); - state.addTypes(llvmVectorTy); + state.addTypes(llvm1DVectorTy); state.addOperands(operands); state.addAttributes(op->getAttrs()); return rewriter.createOperation(state)->getResult(0); @@ -1668,6 +1674,8 @@ using PowFOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = OneToOneConvertToLLVMPattern; +using SignExtendIOpLowering = + VectorConvertToLLVMPattern; using ShiftLeftOpLowering = OneToOneConvertToLLVMPattern; using SignedDivIOpLowering = @@ -1687,6 +1695,8 @@ using UnsignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; using
XOrOpLowering = VectorConvertToLLVMPattern; +using ZeroExtendIOpLowering = + VectorConvertToLLVMPattern; /// Lower `std.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is @@ -2366,17 +2376,17 @@ return handleMultidimensionalVectors( op.getOperation(), operands, *getTypeConverter(), - [&](Type llvmVectorTy, ValueRange operands) { + [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( - {LLVM::getVectorNumElements(llvmVectorTy).getFixedValue()}, + {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, floatType), floatOne); auto one = - rewriter.create(loc, llvmVectorTy, splatAttr); + rewriter.create(loc, llvm1DVectorTy, splatAttr); auto sqrt = - rewriter.create(loc, llvmVectorTy, operands[0]); - return rewriter.create(loc, llvmVectorTy, one, sqrt); + rewriter.create(loc, llvm1DVectorTy, operands[0]); + return rewriter.create(loc, llvm1DVectorTy, one, sqrt); }, rewriter); } @@ -3050,21 +3060,11 @@ using Super::Super; }; -struct SignExtendIOpLowering - : public OneToOneConvertToLLVMPattern { - using Super::Super; -}; - struct TruncateIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; -struct ZeroExtendIOpLowering - : public OneToOneConvertToLLVMPattern { - using Super::Super; -}; - // Base class for LLVM IR lowering terminator operations with successors. template struct OneToOneLLVMTerminatorLowering @@ -3211,21 +3211,21 @@ auto loc = splatOp.getLoc(); auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, *getTypeConverter()); - auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; - auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; - if (!llvmArrayTy || !llvmVectorTy) + auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; + auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; + if (!llvmNDVectorTy || !llvm1DVectorTy) return failure(); // Construct returned value.
- Value desc = rewriter.create(loc, llvmArrayTy); + Value desc = rewriter.create(loc, llvmNDVectorTy); // Construct a 1-D vector with the splatted value that we insert in all the // places within the returned descriptor. - Value vdesc = rewriter.create(loc, llvmVectorTy); + Value vdesc = rewriter.create(loc, llvm1DVectorTy); auto zero = rewriter.create( loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); - Value v = rewriter.create(loc, llvmVectorTy, vdesc, + Value v = rewriter.create(loc, llvm1DVectorTy, vdesc, adaptor.input(), zero); // Shuffle the value across the desired number of elements. @@ -3237,7 +3237,7 @@ // Iterate of linear index, convert to coords space and insert splatted 1-D // vector in each position. nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { - desc = rewriter.create(loc, llvmArrayTy, desc, v, + desc = rewriter.create(loc, llvmNDVectorTy, desc, v, position); }); rewriter.replaceOp(splatOp, desc); diff --git a/mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/StandardToLLVM/convert-nd-vector-to-llvmir.mlir @@ -0,0 +1,49 @@ +// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s + +// CHECK-LABEL: @vec_bin +func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> { + %0 = addf %arg0, %arg0 : vector<2x2x2xf32> + return %0 : vector<2x2x2xf32> + +// CHECK-NEXT: llvm.mlir.undef : !llvm.array<2 x array<2 x vector<2xf32>>> + +// This extract/extract/fadd/insert block appears 2x2 times, once per inner 1-D vector. +// CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> +// CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> +// CHECK-NEXT: llvm.fadd %{{.*}} : vector<2xf32> +// CHECK-NEXT: llvm.insertvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> + +// We check the
proper indexing of extract/insert in the remaining 3 positions. +// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>> +// CHECK: llvm.insertvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>> +// CHECK: llvm.extractvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> +// CHECK: llvm.insertvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> +// CHECK: llvm.extractvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>> +// CHECK: llvm.insertvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>> +} + +// CHECK-LABEL: @sexti_vector +func @sexti_vector(%arg0 : vector<1x2x3xi32>, %arg1 : vector<1x2x3xi64>) { + // CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi64>>> + // CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>> + // CHECK: llvm.sext %{{.*}} : vector<3xi32> to vector<3xi64> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi64>>> + // CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>> + // CHECK: llvm.sext %{{.*}} : vector<3xi32> to vector<3xi64> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi64>>> + %0 = sexti %arg0: vector<1x2x3xi32> to vector<1x2x3xi64> + return +} + +// CHECK-LABEL: @zexti_vector +func @zexti_vector(%arg0 : vector<1x2x3xi32>, %arg1 : vector<1x2x3xi64>) { + // CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi64>>> + // CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>> + // CHECK: llvm.zext %{{.*}} : vector<3xi32> to vector<3xi64> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi64>>> + // CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>> + // CHECK: llvm.zext %{{.*}} : vector<3xi32> to vector<3xi64> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi64>>> + %0 =
zexti %arg0: vector<1x2x3xi32> to vector<1x2x3xi64> + return +} diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -766,31 +766,6 @@ return } -// CHECK-LABEL: @vec_bin -func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> { - %0 = addf %arg0, %arg0 : vector<2x2x2xf32> - return %0 : vector<2x2x2xf32> - -// CHECK-NEXT: llvm.mlir.undef : !llvm.array<2 x array<2 x vector<2xf32>>> - -// This block appears 2x2 times -// CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> -// CHECK-NEXT: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> -// CHECK-NEXT: llvm.fadd %{{.*}} : vector<2xf32> -// CHECK-NEXT: llvm.insertvalue %{{.*}}[0, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> - -// We check the proper indexing of extract/insert in the remaining 3 positions. -// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>> -// CHECK: llvm.insertvalue %{{.*}}[0, 1] : !llvm.array<2 x array<2 x vector<2xf32>>> -// CHECK: llvm.extractvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> -// CHECK: llvm.insertvalue %{{.*}}[1, 0] : !llvm.array<2 x array<2 x vector<2xf32>>> -// CHECK: llvm.extractvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>> -// CHECK: llvm.insertvalue %{{.*}}[1, 1] : !llvm.array<2 x array<2 x vector<2xf32>>> - -// And we're done -// CHECK-NEXT: return -} - // CHECK-LABEL: @splat // CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32> // CHECK-SAME: %[[ELT:arg[0-9]+]]: f32