diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -157,6 +157,13 @@ LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Special case for inserting scalar values into size-1 vectors: such vectors convert to scalars, so the inserted value becomes the result. + if (insertOp.getSourceType().isIntOrFloat() && + insertOp.getDestVectorType().getNumElements() == 1) { + rewriter.replaceOp(insertOp, adaptor.source()); + return success(); + } + if (insertOp.getSourceType().isa() || !spirv::CompositeType::isValid(insertOp.getDestVectorType())) return failure(); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp @@ -22,13 +22,13 @@ using namespace mlir; namespace { -struct LowerVectorToSPIRVPass - : public ConvertVectorToSPIRVBase { +struct ConvertVectorToSPIRVPass + : public ConvertVectorToSPIRVBase { void runOnOperation() override; }; } // namespace -void LowerVectorToSPIRVPass::runOnOperation() { +void ConvertVectorToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -37,17 +37,26 @@ SPIRVConversionTarget::get(targetAttr); SPIRVTypeConverter typeConverter(targetAttr); + + // Use UnrealizedConversionCast as the bridge so that we don't need to pull in + // patterns for other dialects. 
+ auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) { + auto cast = builder.create(loc, type, inputs); + return Optional(cast.getResult(0)); + }; + typeConverter.addSourceMaterialization(addUnrealizedCast); + typeConverter.addTargetMaterialization(addUnrealizedCast); + target->addLegalOp(); + RewritePatternSet patterns(context); populateVectorToSPIRVPatterns(typeConverter, patterns); - target->addLegalOp(); - target->addLegalOp(); - - if (failed(applyFullConversion(module, *target, std::move(patterns)))) + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) return signalPassFailure(); } std::unique_ptr> mlir::createConvertVectorToSPIRVPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir --- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -2,152 +2,158 @@ module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } { -// CHECK-LABEL: func @bitcast +// CHECK-LABEL: @bitcast // CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf16> -// CHECK: %{{.+}} = spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16> -// CHECK: %{{.+}} = spv.Bitcast %[[ARG1]] : vector<2xf16> to f32 -func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) { +// CHECK: spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16> +// CHECK: spv.Bitcast %[[ARG1]] : vector<2xf16> to f32 +func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16>, vector<1xf32>) { %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16> %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32> - spv.Return + return %0, %1: vector<4xf16>, vector<1xf32> } } // end module // ----- -// CHECK-LABEL: broadcast +// CHECK-LABEL: @broadcast // CHECK-SAME: %[[A:.*]]: f32 // CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : 
vector<4xf32> // CHECK: spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32> -func @broadcast(%arg0 : f32) { +func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) { %0 = vector.broadcast %arg0 : f32 to vector<4xf32> %1 = vector.broadcast %arg0 : f32 to vector<2xf32> - spv.Return + return %0, %1: vector<4xf32>, vector<2xf32> } // ----- -// CHECK-LABEL: func @extract +// CHECK-LABEL: @extract // CHECK-SAME: %[[ARG:.+]]: vector<2xf32> -// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32> -// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32> -func @extract(%arg0 : vector<2xf32>) { +// CHECK: spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32> +// CHECK: spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32> +func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) { %0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32> %1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32 - spv.Return + return %0, %1: vector<1xf32>, f32 } // ----- -module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } { - -// CHECK-LABEL: func @extract_scalar -// CHECK-SAME: %[[ARG0:.+]]: vector<2xf16> -// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32> -// CHECK: %[[S:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32 -// CHECK: spv.CompositeInsert %[[S]], %[[ARG1]][0 : i32] : f32 into vector<4xf32> -func @extract_scalar(%arg0 : vector<2xf16>, %arg1 : vector<4xf32>) { - %0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32> - %1 = vector.extract %0[0] : vector<1xf32> - %2 = vector.insert %1, %arg1[0] : f32 into vector<4xf32> - spv.Return +// CHECK-LABEL: @extract_size1_vector +// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32> +// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] +// CHECK: return %[[R]] +func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 { + %0 = vector.extract %arg0[0] : vector<1xf32> + return %0: f32 } -} // end module +// ----- + +// CHECK-LABEL: 
@insert +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32 +// CHECK: spv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32> +func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { + %1 = vector.insert %arg1, %arg0[2] : f32 into vector<4xf32> + return %1: vector<4xf32> +} // ----- -// CHECK-LABEL: extract_insert -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -// CHECK: spv.CompositeInsert %[[S]], %[[V]][0 : i32] : f32 into vector<4xf32> -func @extract_insert(%arg0 : vector<4xf32>) { - %0 = vector.extract %arg0[1] : vector<4xf32> - %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32> - spv.Return +// CHECK-LABEL: @insert_size1_vector +// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32 +// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]] +// CHECK: return %[[R]] +func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> { + %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32> + return %1 : vector<1xf32> } // ----- -// CHECK-LABEL: extract_element +// CHECK-LABEL: @extract_element // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 // CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func @extract_element(%arg0 : vector<4xf32>, %id : i32) { +func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - spv.ReturnValue %0: f32 + return %0: f32 } // ----- -func @extract_element_index(%arg0 : vector<4xf32>, %id : index) { -// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}} +// CHECK-LABEL: @extract_element_index +func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { + // CHECK: vector.extractelement %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - spv.ReturnValue %0: f32 + return %0: f32 } // ----- -func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) { -// 
expected-error @+1 {{failed to legalize operation 'vector.extractelement'}} +// CHECK-LABEL: @extract_element_size5_vector +func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 { + // CHECK: vector.extractelement %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32> - spv.ReturnValue %0: f32 + return %0: f32 } // ----- -// CHECK-LABEL: func @extract_strided_slice +// CHECK-LABEL: @extract_strided_slice // CHECK-SAME: %[[ARG:.+]]: vector<4xf32> -// CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32> -// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32> -func @extract_strided_slice(%arg0: vector<4xf32>) { +// CHECK: spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32> +// CHECK: spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32> +func @extract_strided_slice(%arg0: vector<4xf32>) -> (vector<2xf32>, vector<1xf32>) { %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> %1 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> - spv.Return + return %0, %1 : vector<2xf32>, vector<1xf32> } // ----- -// CHECK-LABEL: insert_element +// CHECK-LABEL: @insert_element // CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 // CHECK: spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) { +func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - spv.ReturnValue %0: vector<4xf32> + return %0: vector<4xf32> } // ----- -func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) { -// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}} +// CHECK-LABEL: 
@insert_element_index +func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { + // CHECK: vector.insertelement %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - spv.ReturnValue %0: vector<4xf32> + return %0: vector<4xf32> } // ----- -func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) { -// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}} +// CHECK-LABEL: @insert_element_size5_vector +func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> { + // CHECK: vector.insertelement %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> - spv.Return + return %0 : vector<5xf32> } // ----- -// CHECK-LABEL: func @insert_strided_slice +// CHECK-LABEL: @insert_strided_slice // CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32> -// CHECK: %{{.+}} = spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32> -func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) { +// CHECK: spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32> +func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector<4xf32> { %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32> - spv.Return + return %0 : vector<4xf32> } // ----- -// CHECK-LABEL: func @fma +// CHECK-LABEL: @fma // CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> // CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32> -func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) { +func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> { %0 = vector.fma %a, %b, %c: vector<4xf32> - spv.Return + return %0 : vector<4xf32> }