diff --git a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
@@ -0,0 +1,29 @@
+//===- LinalgToSPIRV.h - Linalg to SPIR-V dialect conversion ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides patterns for Linalg to SPIR-V dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
+#define MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
+
+namespace mlir {
+class MLIRContext;
+class OwningRewritePatternList;
+class SPIRVTypeConverter;
+
+/// Appends to a pattern list additional patterns for translating Linalg ops to
+/// SPIR-V ops.
+void populateLinalgToSPIRVPatterns(MLIRContext *context,
+                                   SPIRVTypeConverter &typeConverter,
+                                   OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h
@@ -0,0 +1,25 @@
+//===- LinalgToSPIRVPass.h - Linalg to SPIR-V conversion pass --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a pass for Linalg to SPIR-V dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
+#define MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+/// Creates and returns a pass to convert Linalg ops to SPIR-V ops.
+std::unique_ptr<OpPassBase<ModuleOp>> createLinalgToSPIRVPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
@@ -25,6 +25,7 @@
     SPV_ScopeAttr:$memory_scope,
     SPV_MemorySemanticsAttr:$semantics
   );
+
   let results = (outs
     SPV_Integer:$result
   );
@@ -42,9 +43,19 @@
     SPV_MemorySemanticsAttr:$semantics,
     SPV_Integer:$value
   );
+
   let results = (outs
     SPV_Integer:$result
   );
+
+  let builders = [
+    OpBuilder<
+      [{Builder *builder, OperationState &state, Value pointer,
+        ::mlir::spirv::Scope scope, ::mlir::spirv::MemorySemantics memory,
+        Value value}],
+      [{build(builder, state, value.getType(), pointer, scope, memory, value);}]
+    >
+  ];
 }
 
 // -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -446,14 +446,22 @@
   let regions = (region AnyRegion:$body);
 
   let extraClassDeclaration = [{
-    // Returns the selection header block.
+    /// Returns the selection header block.
     Block *getHeaderBlock();
 
-    // Returns the selection merge block.
+    /// Returns the selection merge block.
     Block *getMergeBlock();
 
-    // Adds a selection merge block containing one spv._merge op.
+    /// Adds a selection merge block containing one spv._merge op.
     void addMergeBlock();
+
+    /// Creates a spv.selection op for `if (<condition>) then { <thenBody> }`
+    /// with `builder`. `builder`'s insertion point will remain right after the
+    /// newly inserted spv.selection op.
+    static SelectionOp createIfThen(
+        Location loc, Value condition,
+        llvm::function_ref<void(OpBuilder *builder)> thenBody,
+        OpBuilder *builder);
   }];
 
   let hasOpcode = 0;
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -58,6 +58,8 @@
                                         OwningRewritePatternList &patterns);
 
 namespace spirv {
+class AccessChainOp;
+
 class SPIRVConversionTarget : public ConversionTarget {
 public:
   /// Creates a SPIR-V conversion target for the given target environment.
@@ -90,6 +92,18 @@
 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
                               OpBuilder &builder);
 
+/// Performs the index computation to get to the element at `indices` of the
+/// memory pointed to by `basePtr`, using the layout map of `baseType`.
+/// Returns a null op if the layout of `baseType` cannot be resolved to static
+/// strides; callers must check the result before using it.
+///
+/// TODO(ravishankarm): This method assumes that the `baseType` is a MemRefType
+/// with AffineMap that has static strides. Extend to handle dynamic strides.
+spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
+                                   MemRefType baseType, Value basePtr,
+                                   ArrayRef<Value> indices, Location loc,
+                                   OpBuilder &builder);
+
 /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
 /// arguments.
 LogicalResult setABIAttrs(FuncOp funcOp, EntryPointABIAttr entryPointInfo,
diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
@@ -54,6 +54,12 @@
 /// target environment (SPIR-V 1.0 with Shader capability and no extra
 /// extensions) if not provided.
 TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+
+/// Queries the local workgroup size from entry point ABI on the nearest
+/// function-like op containing the given `op`. Returns null attribute if not
+/// found.
+DenseIntElementsAttr lookupLocalWorkGroupSize(Operation *op);
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -4,6 +4,7 @@
 add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(LinalgToLLVM)
+add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LoopsToGPU)
 add_subdirectory(LoopToStandard)
 add_subdirectory(StandardToLLVM)
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/LinalgToSPIRV/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_llvm_library(MLIRLinalgToSPIRVTransforms
+  LinalgToSPIRV.cpp
+  LinalgToSPIRVPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+  )
+
+target_link_libraries(MLIRLinalgToSPIRVTransforms
+  MLIRIR
+  MLIRLinalgOps
+  MLIRPass
+  MLIRSPIRV
+  MLIRSupport
+  )
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -0,0 +1,264 @@
+//===- LinalgToSPIRV.cpp - Linalg to SPIR-V dialect conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns true if all of the given `op`'s operands and results are of memref
+/// types.
+static bool areAllValuesMemref(Operation *op) {
+  auto isOfMemrefType = [](Value val) {
+    return val.getType().isa<MemRefType>();
+  };
+
+  return llvm::all_of(op->getOperands(), isOfMemrefType) &&
+         llvm::all_of(op->getResults(), isOfMemrefType);
+}
+
+/// Returns true if `iterators` holds exactly one iterator, of reduction type.
+static bool isLinalgSingleReductionIterator(ArrayAttr iterators) {
+  if (iterators.getValue().size() != 1)
+    return false;
+
+  // dyn_cast yields a null attribute for non-string entries; reject those.
+  auto iterator = (*iterators.begin()).dyn_cast<StringAttr>();
+  if (!iterator || iterator.getValue() != getReductionIteratorTypeName())
+    return false;
+  return true;
+}
+
+/// Returns a `Value` containing the `dim`-th component of the SPIR-V local
+/// invocation ID. This function will create necessary operations with
+/// `builder` at the proper region containing `op`.
+static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
+                                       OpBuilder *builder) {
+  assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
+  Value invocation = spirv::getBuiltinVariableValue(
+      op, spirv::BuiltIn::LocalInvocationId, *builder);
+  Type xType = invocation.getType().cast<VectorType>().getElementType();
+  return builder->create<spirv::CompositeExtractOp>(
+      loc, xType, invocation, builder->getI32ArrayAttr({dim}));
+}
+
+namespace {
+enum class BinaryOpKind {
+  Unknown,
+  IAdd,
+};
+} // namespace
+
+/// Returns the binary op kind if the given linalg.generic op has the following
+/// body:
+///
+/// ```
+/// linalg.generic ... {
+///   ^bb(%a: <scalar-type>, %b: <scalar-type>):
+///     %0 = <binary-op> %a, %b: <scalar-type>
+///     linalg.yield %0: <scalar-type>
+/// }
+/// ```
+static BinaryOpKind getScalarBinaryOpKind(linalg::GenericOp op) {
+  auto &region = op.region();
+  if (region.empty() || !has_single_element(region.getBlocks()))
+    return BinaryOpKind::Unknown;
+
+  Block &block = region.front();
+  if (block.getNumArguments() != 2 ||
+      !block.getArgument(0).getType().isIntOrFloat() ||
+      !block.getArgument(1).getType().isIntOrFloat())
+    return BinaryOpKind::Unknown;
+
+  auto &ops = block.getOperations();
+  if (!has_single_element(block.without_terminator()))
+    return BinaryOpKind::Unknown;
+
+  using mlir::matchers::m_Val;
+  auto a = m_Val(block.getArgument(0));
+  auto b = m_Val(block.getArgument(1));
+
+  auto addPattern = m_Op<linalg::YieldOp>(m_Op<AddIOp>(a, b));
+  if (addPattern.match(&ops.back()))
+    return BinaryOpKind::IAdd;
+
+  return BinaryOpKind::Unknown;
+}
+
+//===----------------------------------------------------------------------===//
+// Reduction (single workgroup)
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
+/// that the linalg.generic op is performing reduction with a workload size that
+/// can fit in one workgroup.
+class SingleWorkgroupReduction final
+    : public SPIRVOpLowering<linalg::GenericOp> {
+public:
+  using SPIRVOpLowering<linalg::GenericOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+} // namespace
+
+PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
+    linalg::GenericOp genericOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  Operation *op = genericOp.getOperation();
+
+  // Make sure the linalg.generic is working on memrefs.
+  if (!areAllValuesMemref(op))
+    return matchFailure();
+
+  // Make sure this is a reduction with one input and one output.
+  if (genericOp.args_in().getZExtValue() != 1 ||
+      genericOp.args_out().getZExtValue() != 1)
+    return matchFailure();
+
+  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
+  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
+
+  // Make sure the original input has one dimension.
+  if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
+    return matchFailure();
+  // Make sure the original output has one element.
+  if (!originalOutputType.hasStaticShape() ||
+      originalOutputType.getNumElements() != 1)
+    return matchFailure();
+
+  if (!isLinalgSingleReductionIterator(genericOp.iterator_types()))
+    return matchFailure();
+
+  if (genericOp.indexing_maps().getValue().size() != 2)
+    return matchFailure();
+
+  auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>();
+  auto outputMap =
+      genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>();
+  // The indexing map for the input should be `(i) -> (i)`.
+  if (inputMap.getValue() !=
+      AffineMap::get(1, 0, {getAffineDimExpr(0, op->getContext())}))
+    return matchFailure();
+  // The indexing map for the output should be `(i) -> (0)`.
+  if (outputMap.getValue() !=
+      AffineMap::get(1, 0, {getAffineConstantExpr(0, op->getContext())}))
+    return matchFailure();
+
+  auto binaryOpKind = getScalarBinaryOpKind(genericOp);
+  if (binaryOpKind == BinaryOpKind::Unknown)
+    return matchFailure();
+
+  // Query the shader interface for local workgroup size to make sure the
+  // invocation configuration fits with the input memref's shape.
+  DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
+  if (!localSize)
+    return matchFailure();
+
+  if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
+    return matchFailure();
+  if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
+                   [](const APInt &size) { return !size.isOneValue(); }))
+    return matchFailure();
+
+  // TODO(antiagainst): Query the target environment to make sure the current
+  // workload fits in a local workgroup.
+
+  Value convertedInput = operands[0], convertedOutput = operands[1];
+  Location loc = genericOp.getLoc();
+
+  // Get the invocation ID.
+  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);
+
+  // TODO(antiagainst): Load to Workgroup storage class first.
+
+  // Get the input element accessed by this invocation.
+  Value inputElementPtr = spirv::getElementPtr(
+      typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
+  Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
+
+  // Perform the group reduction operation.
+  Value groupOperation;
+#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp)                         \
+  case BinaryOpKind::opKind: {                                                 \
+    groupOperation = rewriter.create<spirv::spvOp>(                            \
+        loc, originalInputType.getElementType(), spirv::Scope::Subgroup,       \
+        spirv::GroupOperation::Reduce, inputElement,                           \
+        /*cluster_size=*/ArrayRef<Value>());                                   \
+  } break
+  switch (binaryOpKind) {
+    CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
+  case BinaryOpKind::Unknown:
+    llvm_unreachable("failed to reject match");
+  }
+#undef CREATE_GROUP_NON_UNIFORM_BIN_OP
+
+  // Get the output element accessed by this reduction.
+  Value zero = spirv::ConstantOp::getZero(
+      typeConverter.getIndexType(rewriter.getContext()), loc, &rewriter);
+  SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
+  Value outputElementPtr =
+      spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput,
+                           zeroIndices, loc, rewriter);
+
+  // Write out the final reduction result. This should be only conducted by one
+  // invocation. We use spv.GroupNonUniformElect to find the invocation with the
+  // lowest ID.
+  //
+  // ```
+  // if (spv.GroupNonUniformElect) { output = ... }
+  // ```
+
+  Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
+      loc, spirv::Scope::Subgroup);
+
+  auto createAtomicOp = [&](OpBuilder *builder) {
+#define CREATE_ATOMIC_BIN_OP(opKind, spvOp)                                    \
+  case BinaryOpKind::opKind: {                                                 \
+    builder->create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \
+                                  spirv::MemorySemantics::AcquireRelease,      \
+                                  groupOperation);                             \
+  } break
+    switch (binaryOpKind) {
+      CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp);
+    case BinaryOpKind::Unknown:
+      llvm_unreachable("failed to reject match");
+    }
+#undef CREATE_ATOMIC_BIN_OP
+  };
+
+  spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, &rewriter);
+
+  rewriter.eraseOp(genericOp);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
+                                         SPIRVTypeConverter &typeConverter,
+                                         OwningRewritePatternList &patterns) {
+  patterns.insert<SingleWorkgroupReduction>(context, typeConverter);
+}
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -0,0 +1,51 @@
+//===- LinalgToSPIRVPass.cpp - Linalg to SPIR-V conversion pass -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
+#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// A module pass converting MLIR Linalg ops into SPIR-V ops.
+class LinalgToSPIRVPass : public ModulePass<LinalgToSPIRVPass> {
+  void runOnModule() override;
+};
+} // namespace
+
+void LinalgToSPIRVPass::runOnModule() {
+  MLIRContext *context = &getContext();
+  ModuleOp module = getModule();
+
+  SPIRVTypeConverter typeConverter;
+  OwningRewritePatternList patterns;
+  populateLinalgToSPIRVPatterns(context, typeConverter, patterns);
+  populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+
+  auto targetEnv = spirv::lookupTargetEnvOrDefault(module);
+  std::unique_ptr<ConversionTarget> target =
+      spirv::SPIRVConversionTarget::get(targetEnv, context);
+
+  // Allow builtin ops.
+  target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target->addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
+
+  if (failed(applyFullConversion(module, *target, patterns)))
+    return signalPassFailure();
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLinalgToSPIRVPass() {
+  return std::make_unique<LinalgToSPIRVPass>();
+}
+
+static PassRegistration<LinalgToSPIRVPass>
+    pass("convert-linalg-to-spirv", "Convert Linalg ops to SPIR-V ops");
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -141,48 +141,6 @@
 
 } // namespace
 
-//===----------------------------------------------------------------------===//
-// Utility functions for operation conversion
-//===----------------------------------------------------------------------===//
-
-/// Performs the index computation to get to the element pointed to by
-/// `indices` using the layout map of `baseType`.
-
-// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
-// MemRefType with AffineMap that has static strides. Handle dynamic strides
-static spirv::AccessChainOp getElementPtr(OpBuilder &builder,
-                                          SPIRVTypeConverter &typeConverter,
-                                          Location loc, MemRefType origBaseType,
-                                          Value basePtr,
-                                          ArrayRef<Value> indices) {
-  // Get base and offset of the MemRefType and verify they are static.
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
-      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
-    return nullptr;
-  }
-
-  auto indexType = typeConverter.getIndexType(builder.getContext());
-
-  Value ptrLoc = nullptr;
-  assert(indices.size() == strides.size());
-  for (auto index : enumerate(indices)) {
-    Value strideVal = builder.create<spirv::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
-    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
-    ptrLoc =
-        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
-                : update);
-  }
-  SmallVector<Value, 2> linearizedIndices;
-  // Add a '0' at the start to index into the struct.
-  linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
-      loc, indexType, IntegerAttr::get(indexType, 0)));
-  linearizedIndices.push_back(ptrLoc);
-  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
-}
-
 //===----------------------------------------------------------------------===//
 // ConstantOp with composite type.
//===----------------------------------------------------------------------===// @@ -331,9 +289,9 @@ LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpOperandAdaptor loadOperands(operands); - auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(), - loadOp.memref().getType().cast(), - loadOperands.memref(), loadOperands.indices()); + auto loadPtr = spirv::getElementPtr( + typeConverter, loadOp.memref().getType().cast(), + loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); rewriter.replaceOpWithNewOp(loadOp, loadPtr); return matchSuccess(); } @@ -374,10 +332,10 @@ StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { StoreOpOperandAdaptor storeOperands(operands); - auto storePtr = - getElementPtr(rewriter, typeConverter, storeOp.getLoc(), - storeOp.memref().getType().cast(), - storeOperands.memref(), storeOperands.indices()); + auto storePtr = spirv::getElementPtr( + typeConverter, storeOp.memref().getType().cast(), + storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(), + rewriter); rewriter.replaceOpWithNewOp(storeOp, storePtr, storeOperands.value()); return matchSuccess(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -292,6 +292,41 @@ return builder.create(op->getLoc(), ptr); } +//===----------------------------------------------------------------------===// +// Index calculation +//===----------------------------------------------------------------------===// + +spirv::AccessChainOp mlir::spirv::getElementPtr( + SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, + ArrayRef indices, Location loc, OpBuilder &builder) { + // Get base and offset of the MemRefType and verify they are static. 
+ int64_t offset; + SmallVector strides; + if (failed(getStridesAndOffset(baseType, strides, offset)) || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { + return nullptr; + } + + auto indexType = typeConverter.getIndexType(builder.getContext()); + + Value ptrLoc = nullptr; + assert(indices.size() == strides.size()); + for (auto index : enumerate(indices)) { + Value strideVal = builder.create( + loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); + Value update = builder.create(loc, strideVal, index.value()); + ptrLoc = + (ptrLoc ? builder.create(loc, ptrLoc, update).getResult() + : update); + } + SmallVector linearizedIndices; + // Add a '0' at the start to index into the struct. + linearizedIndices.push_back(builder.create( + loc, indexType, IntegerAttr::get(indexType, 0))); + linearizedIndices.push_back(ptrLoc); + return builder.create(loc, basePtr, linearizedIndices); +} + //===----------------------------------------------------------------------===// // Set ABI attributes for lowering entry functions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2755,6 +2755,38 @@ builder.create(getLoc()); } +spirv::SelectionOp spirv::SelectionOp::createIfThen( + Location loc, Value condition, + llvm::function_ref thenBody, OpBuilder *builder) { + auto selectionControl = builder->getI32IntegerAttr( + static_cast(spirv::SelectionControl::None)); + auto selectionOp = builder->create(loc, selectionControl); + + selectionOp.addMergeBlock(); + Block *mergeBlock = selectionOp.getMergeBlock(); + Block *thenBlock = nullptr; + + // Build the "then" block. 
+ { + OpBuilder::InsertionGuard guard(*builder); + thenBlock = builder->createBlock(mergeBlock); + thenBody(builder); + builder->create(loc, mergeBlock); + } + + // Build the header block. + { + OpBuilder::InsertionGuard guard(*builder); + builder->createBlock(thenBlock); + builder->create( + loc, condition, thenBlock, + /*trueArguments=*/ArrayRef(), mergeBlock, + /*falseArguments=*/ArrayRef()); + } + + return selectionOp; +} + namespace { // Blocks from the given `spv.selection` operation must satisfy the following // layout: diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/SPIRV/TargetAndABI.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/FunctionSupport.h" #include "mlir/IR/Operation.h" using namespace mlir; @@ -62,3 +63,16 @@ return attr; return getDefaultTargetEnv(op->getContext()); } + +DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) { + while (op && !op->hasTrait()) + op = op->getParentOp(); + if (!op) + return {}; + + if (auto attr = op->getAttrOfType( + spirv::getEntryPointABIAttrName())) + return attr.local_size(); + + return {}; +} diff --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir @@ -0,0 +1,162 @@ +// RUN: mlir-opt -split-input-file -convert-linalg-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s + +//===----------------------------------------------------------------------===// +// Single workgroup reduction +//===----------------------------------------------------------------------===// + +#single_workgroup_reduction_trait = { + args_in = 1, + args_out = 1, + iterator_types = ["reduction"], + indexing_maps = [ + 
affine_map<(i) -> (i)>, + affine_map<(i) -> (0)> + ] +} + +module attributes { + spv.target_env = { + version = 3 : i32, + extensions = [], + capabilities = [1: i32, 63: i32] // Shader, GroupNonUniformArithmetic + } +} { + +// CHECK: spv.globalVariable +// CHECK-SAME: built_in("LocalInvocationId") + +// CHECK: func @single_workgroup_reduction +// CHECK-SAME: (%[[INPUT:.+]]: !spv.ptr{{.+}}, %[[OUTPUT:.+]]: !spv.ptr{{.+}}) + +// CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 +// CHECK: %[[ID:.+]] = spv.Load "Input" %{{.+}} : vector<3xi32> +// CHECK: %[[X:.+]] = spv.CompositeExtract %[[ID]][0 : i32] + +// CHECK: %[[INPTR:.+]] = spv.AccessChain %[[INPUT]][%[[ZERO]], %[[X]]] +// CHECK: %[[VAL:.+]] = spv.Load "StorageBuffer" %[[INPTR]] : i32 +// CHECK: %[[ADD:.+]] = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32 + +// CHECK: %[[OUTPTR:.+]] = spv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]] +// CHECK: %[[ELECT:.+]] = spv.GroupNonUniformElect "Subgroup" : i1 + +// CHECK: spv.selection { +// CHECK: spv.BranchConditional %[[ELECT]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: spv.AtomicIAdd "Device" "AcquireRelease" %[[OUTPTR]], %[[ADD]] +// CHECK: spv.Branch ^bb2 +// CHECK: ^bb2: +// CHECK: spv._merge +// CHECK: } +// CHECK: spv.Return + +func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes { + spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>} +} { + linalg.generic #single_workgroup_reduction_trait %input, %output { + ^bb(%in: i32, %out: i32): + %sum = addi %in, %out : i32 + linalg.yield %sum : i32 + } : memref<16xi32>, memref<1xi32> + spv.Return +} +} + +// ----- + +// Missing shader entry point ABI + +#single_workgroup_reduction_trait = { + args_in = 1, + args_out = 1, + iterator_types = ["reduction"], + indexing_maps = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (0)> + ] +} + +module attributes { + spv.target_env = { + version = 3 : i32, + extensions = [], + capabilities = [1: i32, 63: i32] // Shader, 
GroupNonUniformArithmetic + } +} { +func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) { + // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} + linalg.generic #single_workgroup_reduction_trait %input, %output { + ^bb(%in: i32, %out: i32): + %sum = addi %in, %out : i32 + linalg.yield %sum : i32 + } : memref<16xi32>, memref<1xi32> + return +} +} + +// ----- + +// Mismatch between shader entry point ABI and input memref shape + +#single_workgroup_reduction_trait = { + args_in = 1, + args_out = 1, + iterator_types = ["reduction"], + indexing_maps = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (0)> + ] +} + +module attributes { + spv.target_env = { + version = 3 : i32, + extensions = [], + capabilities = [1: i32, 63: i32] // Shader, GroupNonUniformArithmetic + } +} { +func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes { + spv.entry_point_abi = {local_size = dense<[32, 1, 1]>: vector<3xi32>} +} { + // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} + linalg.generic #single_workgroup_reduction_trait %input, %output { + ^bb(%in: i32, %out: i32): + %sum = addi %in, %out : i32 + linalg.yield %sum : i32 + } : memref<16xi32>, memref<1xi32> + spv.Return +} +} + +// ----- + +// Unsupported multi-dimension input memref + +#single_workgroup_reduction_trait = { + args_in = 1, + args_out = 1, + iterator_types = ["parallel", "reduction"], + indexing_maps = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i)> + ] +} + +module attributes { + spv.target_env = { + version = 3 : i32, + extensions = [], + capabilities = [1: i32, 63: i32] // Shader, GroupNonUniformArithmetic + } +} { +func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi32>) attributes { + spv.entry_point_abi = {local_size = dense<[16, 8, 1]>: vector<3xi32>} +} { + // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} + linalg.generic 
#single_workgroup_reduction_trait %input, %output { + ^bb(%in: i32, %out: i32): + %sum = addi %in, %out : i32 + linalg.yield %sum : i32 + } : memref<16x8xi32>, memref<16xi32> + spv.Return +} +} diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -40,6 +40,7 @@ MLIRQuantOps MLIRROCDLIR MLIRSPIRV + MLIRLinalgToSPIRVTransforms MLIRStandardToSPIRVTransforms MLIRSPIRVTestPasses MLIRSPIRVTransforms