diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -46,6 +46,10 @@
 /// functions using the specification in the `spv.entry_point_abi` attribute.
 std::unique_ptr<OperationPass<spirv::ModuleOp>> createLowerABIAttributesPass();
 
+/// Creates an operation pass that rewrites sequential chains of
+/// spv.CompositeInsert into spv.CompositeConstruct.
+std::unique_ptr<OperationPass<spirv::ModuleOp>> createRewriteInsertsPass();
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Passes.td
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.td
@@ -22,6 +22,12 @@
   let constructor = "mlir::spirv::createLowerABIAttributesPass()";
 }
 
+def SPIRVRewriteInsertsPass : Pass<"spirv-rewrite-inserts", "spirv::ModuleOp"> {
+  let summary = "Rewrite sequential chains of spv.CompositeInsert operations into "
+                "spv.CompositeConstruct operations";
+  let constructor = "mlir::spirv::createRewriteInsertsPass()";
+}
+
 def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
   let summary = "Deduce and attach minimal (version, capabilities, extensions) "
                 "requirements to spv.module ops";
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRSPIRVTransforms
   DecorateSPIRVCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
+  RewriteInsertsPass.cpp
   UpdateVCEPass.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -0,0 +1,119 @@
+//===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to rewrite sequential chains of
+// `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Replaces sequential chains of `spirv::CompositeInsertOp` operations with a
+/// single `spirv::CompositeConstructOp` operation where possible.
+class RewriteInsertsPass
+    : public SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
+public:
+  void runOnOperation() override;
+
+private:
+  /// Collects the sequential insertion chain that ends at the given
+  /// `spirv::CompositeInsertOp`. Succeeds only if `op` writes the last element
+  /// of its composite and every earlier element is written, in order, by a
+  /// directly chained single-index insert; then `insertions[i]` is the op
+  /// writing element `i`. On failure the contents of `insertions` are unspecified.
+  LogicalResult
+  collectInsertionChain(spirv::CompositeInsertOp op,
+                        SmallVectorImpl<spirv::CompositeInsertOp> &insertions);
+};
+
+} // anonymous namespace
+
+void RewriteInsertsPass::runOnOperation() {
+  SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList;
+  getOperation().walk([this, &workList](spirv::CompositeInsertOp op) {
+    SmallVector<spirv::CompositeInsertOp, 4> insertions;
+    if (succeeded(collectInsertionChain(op, insertions)))
+      workList.push_back(insertions);
+  });
+
+  for (const auto &insertions : workList) {
+    auto lastCompositeInsertOp = insertions.back();
+    auto compositeType = lastCompositeInsertOp.getType();
+    auto location = lastCompositeInsertOp.getLoc();
+
+    SmallVector<Value, 4> operands;
+    // Collect inserted objects.
+    for (auto insertionOp : insertions)
+      operands.push_back(insertionOp.object());
+
+    OpBuilder builder(lastCompositeInsertOp);
+    auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
+        location, compositeType, operands);
+
+    lastCompositeInsertOp.replaceAllUsesWith(
+        compositeConstructOp.getOperation()->getResult(0));
+
+    // Erase the chain back to front: each insert is used by its successor, so
+    // an op only becomes use-free once the later ops have been erased. An op
+    // that still has other users is kept (and so keeps its predecessors).
+    for (auto insertOp : llvm::reverse(insertions))
+      if (insertOp.getOperation()->use_empty())
+        insertOp.erase();
+  }
+}
+
+LogicalResult RewriteInsertsPass::collectInsertionChain(
+    spirv::CompositeInsertOp op,
+    SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
+  auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
+  // TODO(denis0x0D): handle nested composite object.
+  if (indicesArrayAttr.size() == 1) {
+    auto numElements =
+        op.composite().getType().cast<spirv::CompositeType>().getNumElements();
+
+    auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
+    // Need a last index to collect a sequential chain.
+    if (index + 1 != numElements)
+      return failure();
+
+    insertions.resize(numElements);
+    while (true) {
+      insertions[index] = op;
+
+      if (index == 0)
+        return success();
+
+      op = dyn_cast_or_null<spirv::CompositeInsertOp>(
+          op.composite().getDefiningOp());
+      if (!op)
+        return failure();
+
+      --index;
+      indicesArrayAttr = op.indices().cast<ArrayAttr>();
+      if ((indicesArrayAttr.size() != 1) ||
+          (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
+        return failure();
+    }
+  }
+  return failure();
+}
+
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+mlir::spirv::createRewriteInsertsPass() {
+  return std::make_unique<RewriteInsertsPass>();
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -spirv-rewrite-inserts -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+
+spv.module Logical GLSL450 {
+  spv.func @rewrite(%value0 : f32, %value1 : f32, %value2 : f32, %value3 : i32, %value4: !spv.array<3xf32>) -> vector<3xf32> "None" {
+    %0 = spv.undef : vector<3xf32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
+    %1 = spv.CompositeInsert %value0, %0[0 : i32] : f32 into vector<3xf32>
+    %2 = spv.CompositeInsert %value1, %1[1 : i32] : f32 into vector<3xf32>
+    %3 = spv.CompositeInsert %value2, %2[2 : i32] : f32 into vector<3xf32>
+
+    %4 = spv.undef : !spv.array<4xf32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spv.array<4 x f32>
+    %5 = spv.CompositeInsert %value0, %4[0 : i32] : f32 into !spv.array<4xf32>
+    %6 = spv.CompositeInsert %value1, %5[1 : i32] : f32 into !spv.array<4xf32>
+    %7 = spv.CompositeInsert %value2, %6[2 : i32] : f32 into !spv.array<4xf32>
+    %8 = spv.CompositeInsert %value0, %7[3 : i32] : f32 into !spv.array<4xf32>
+
+    %9 = spv.undef : !spv.struct<f32, i32, f32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct<f32, i32, f32>
+    %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct<f32, i32, f32>
+    %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct<f32, i32, f32>
+    %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct<f32, i32, f32>
+
+    %13 = spv.undef : !spv.struct<f32, !spv.array<3xf32>>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct<f32, !spv.array<3 x f32>>
+    %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct<f32, !spv.array<3xf32>>
+    %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct<f32, !spv.array<3xf32>>
+
+    spv.ReturnValue %3 : vector<3xf32>
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  // CHECK-LABEL: @keep_shared_insert
+  spv.func @keep_shared_insert(%value0 : f32, %value1 : f32) -> vector<2xf32> "None" {
+    %0 = spv.undef : vector<2xf32>
+    // The partial insert is also used by the FAdd below, so it must survive.
+    // CHECK: [[PARTIAL:%.*]] = spv.CompositeInsert
+    %1 = spv.CompositeInsert %value0, %0[0 : i32] : f32 into vector<2xf32>
+    // CHECK-NEXT: [[FULL:%.*]] = spv.CompositeConstruct
+    %2 = spv.CompositeInsert %value1, %1[1 : i32] : f32 into vector<2xf32>
+    // CHECK-NEXT: spv.FAdd [[PARTIAL]], [[FULL]]
+    %3 = spv.FAdd %1, %2 : vector<2xf32>
+    spv.ReturnValue %3 : vector<2xf32>
+  }
+}