Index: mlir/include/mlir/Conversion/Passes.h
===================================================================
--- mlir/include/mlir/Conversion/Passes.h
+++ mlir/include/mlir/Conversion/Passes.h
@@ -29,6 +29,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
 
 namespace mlir {
 
Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -381,4 +381,15 @@
   let dependentDialects = ["ROCDL::ROCDLDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToSPIRV
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv", "ModuleOp"> {
+  let summary = "Lower the operations from the vector dialect into the SPIR-V "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToSPIRVPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
 #endif // MLIR_CONVERSION_PASSES
Index: mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h
@@ -0,0 +1,29 @@
+//===- ConvertVectorToSPIRV.h - Vector to SPIR-V patterns -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides patterns for lowering Vector Ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H
+#define MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class SPIRVTypeConverter;
+
+/// Appends to a pattern list additional patterns for translating Vector Ops to
+/// SPIR-V ops.
+void populateVectorToSPIRVPatterns(MLIRContext *context,
+                                   SPIRVTypeConverter &typeConverter,
+                                   OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H
Index: mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h
@@ -0,0 +1,25 @@
+//=- ConvertVectorToSPIRVPass.h - Pass converting Vector to SPIRV -*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides a pass to convert Vector ops to SPIR-V ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H
+#define MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+/// Creates a module pass that converts Vector dialect ops to SPIR-V ops.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToSPIRVPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRVPASS_H
Index: mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
@@ -161,6 +161,11 @@
   let results = (outs
     SPV_Composite:$result
   );
+
+  let builders = [
+    OpBuilder<[{OpBuilder &builder, OperationState &state, Value object,
+                Value composite, ArrayRef<int32_t> indices}]>
+  ];
 }
 
 #endif // SPIRV_COMPOSITE_OPS
Index: mlir/lib/Conversion/CMakeLists.txt
===================================================================
--- mlir/lib/Conversion/CMakeLists.txt
+++ mlir/lib/Conversion/CMakeLists.txt
@@ -19,3 +19,4 @@
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
+add_subdirectory(VectorToSPIRV)
Index: mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIRVectorToSPIRV
+  VectorToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+  intrinsics_gen
+
+  LINK_LIBS PUBLIC
+  MLIRSPIRV
+  MLIRVector
+  MLIRTransforms
+  )
Index: mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -0,0 +1,137 @@
+//===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate SPIRV operations for Vector
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers vector.broadcast of a scalar to spv.CompositeConstruct, repeating
+/// the scalar once per element of the result vector. Broadcasts whose source
+/// is itself a vector, or whose result type is not a valid SPIR-V composite,
+/// are not matched.
+struct VectorBroadcastConvert final
+    : public SPIRVOpLowering<vector::BroadcastOp> {
+  using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (broadcastOp.source().getType().isa<VectorType>() ||
+        !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
+      return failure();
+    vector::BroadcastOp::Adaptor adaptor(operands);
+    SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
+                                 adaptor.source());
+    Value construct = rewriter.create<spirv::CompositeConstructOp>(
+        broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
+    rewriter.replaceOp(broadcastOp, construct);
+    return success();
+  }
+};
+
+/// Lowers vector.extract of a scalar element to spv.CompositeExtract.
+/// Extracts that produce a vector, or that read from a vector type which is
+/// not a valid SPIR-V composite, are not matched.
+struct VectorExtractOpConvert final
+    : public SPIRVOpLowering<vector::ExtractOp> {
+  using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (extractOp.getType().isa<VectorType>() ||
+        !spirv::CompositeType::isValid(extractOp.getVectorType()))
+      return failure();
+    vector::ExtractOp::Adaptor adaptor(operands);
+    // Only the leading position entry is used. NOTE(review): this relies on
+    // the source being a 1-D vector, which CompositeType::isValid is expected
+    // to guarantee; confirm before relaxing the checks above.
+    int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
+    Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
+        extractOp.getLoc(), adaptor.vector(), id);
+    rewriter.replaceOp(extractOp, newExtract);
+    return success();
+  }
+};
+
+/// Lowers vector.insert of a scalar element to spv.CompositeInsert.
+/// Inserts of a vector value, or into a destination vector type which is not
+/// a valid SPIR-V composite, are not matched.
+struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
+  using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (insertOp.getSourceType().isa<VectorType>() ||
+        !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
+      return failure();
+    vector::InsertOp::Adaptor adaptor(operands);
+    // As for extract, only the leading position entry is used (1-D dest).
+    int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
+    Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
+        insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
+    rewriter.replaceOp(insertOp, newInsert);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
+                                         SPIRVTypeConverter &typeConverter,
+                                         OwningRewritePatternList &patterns) {
+  patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
+                  VectorInsertOpConvert>(context, typeConverter);
+}
+
+namespace {
+/// Module pass applying the Vector-to-SPIR-V patterns, using the SPIR-V
+/// target environment looked up on the module (or the default one).
+struct LowerVectorToSPIRVPass
+    : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void LowerVectorToSPIRVPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ModuleOp module = getOperation();
+
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  std::unique_ptr<ConversionTarget> target =
+      spirv::SPIRVConversionTarget::get(targetAttr);
+
+  SPIRVTypeConverter typeConverter(targetAttr);
+  OwningRewritePatternList patterns;
+  populateVectorToSPIRVPatterns(context, typeConverter, patterns);
+
+  // The conversion is full, so every op left after rewriting must be legal:
+  // keep the enclosing module, its terminator and functions legal.
+  target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target->addLegalOp<FuncOp>();
+
+  if (failed(applyFullConversion(module, *target, patterns)))
+    return signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToSPIRVPass() {
+  return std::make_unique<LowerVectorToSPIRVPass>();
+}
Index: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1410,6 +1410,14 @@
 // spv.CompositeInsert
 //===----------------------------------------------------------------------===//
 
+// Builds a CompositeInsertOp whose result type is the type of `composite`.
+void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
+                                     Value object, Value composite,
+                                     ArrayRef<int32_t> indices) {
+  auto indexAttr = builder.getI32ArrayAttr(indices);
+  build(builder, state, composite.getType(), object, composite, indexAttr);
+}
+
 static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
                                           OperationState &state) {
   SmallVector<OpAsmParser::OperandType, 2> operands;
Index: mlir/test/Conversion/VectorToSPIRV/simple.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv %s -o - | FileCheck %s
+
+// CHECK-LABEL: broadcast
+//  CHECK-SAME: %[[A:.*]]: f32
+//       CHECK:   spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
+//       CHECK:   spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32>
+func @broadcast(%arg0 : f32) {
+  %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
+  %1 = vector.broadcast %arg0 : f32 to vector<2xf32>
+  spv.Return
+}
+
+// -----
+
+// CHECK-LABEL: extract_insert
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>
+//       CHECK:   %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+//       CHECK:   spv.CompositeInsert %[[S]], %[[V]][0 : i32] : f32 into vector<4xf32>
+func @extract_insert(%arg0 : vector<4xf32>) {
+  %0 = vector.extract %arg0[1] : vector<4xf32>
+  %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32>
+  spv.Return
+}