[uArch][XeGPU] Add XeGPU uArch definition. #153706
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir
Author: Md Abdullah Shahneous Bari (mshahneo)
Changes
The uArch infrastructure provides:
Add support for PVC and BMG architectures.
Patch is 32.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153706.diff
12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
new file mode 100644
index 0000000000000..9179838f8c148
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -0,0 +1,182 @@
+//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Xe2 uArch definition.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+struct XeCoreInfo {
+ uint32_t num_threads;
+ SharedMemory shared_memory;
+ uint32_t num_vector_units;
+ uint32_t num_matrix_units;
+
+ // Constructor
+ XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
+ uint32_t num_vector_units, uint32_t num_matrix_units)
+ : num_threads(num_threads), shared_memory(shared_memory),
+ num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
+ }
+};
+
+struct Xe2Plus : public uArch {
+ XeCoreInfo xe_core;
+ Xe2Plus(
+ const std::string &archName, const std::string &archDescription,
+ const XeCoreInfo &xeCore,
+ const std::vector<uArchHierarchyComponent> &hierarchy = {},
+ const std::map<std::string, RegisterFileInfo> &regInfo = {},
+ const std::vector<CacheInfo> &cacheInfo = {},
+ const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
+ : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs),
+ xe_core(xeCore) {}
+};
+
+// struct to represent DPAS instruction
+struct DPASInstruction : public Instruction, public MMAInstructionInterface {
+ DPASInstruction()
+ : Instruction("dpas", // name
+ "Dot Product Accumulate") // description
+ {}
+
+ // Override all virtuals from MatrixOpInterface
+ virtual std::vector<std::pair<uint32_t, uint32_t>>
+ getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override;
+ virtual std::vector<mlir::Type>
+ getSupportedTypes(MLIRContext &context, MMAOpndEnum matrixType) override;
+ virtual bool
+ checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape,
+ mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) override;
+ virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) override;
+ virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
+ mlir::Type BType, mlir::Type CType,
+ mlir::Type DType) override;
+ virtual std::vector<uint32_t> getSupportedM(mlir::Type type) override;
+ virtual std::vector<uint32_t> getSupportedK(mlir::Type type) override;
+ virtual std::vector<uint32_t> getSupportedN(mlir::Type type) override;
+};
+
+namespace PVCuArch {
+struct PVCuArch : public Xe2Plus {
+ // Maintains ownership of the instructions owned by PVCuArch
+ std::vector<std::shared_ptr<Instruction>> owned_instructions;
+ PVCuArch()
+ : Xe2Plus("pvc", // archName
+ "Ponte Vecchio Architecture", // archDescription
+ XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
+ {/* hierarchy */}, // Optional: empty
+ {/* register_file_info */}, // Optional: empty
+ {/* cache_info */} // Optional: empty
+ ) {
+ // Initialize uArchHierarchy
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2));
+ // Initialize register file info
+ // GRF
+ this->register_file_info.emplace(
+ "GRF",
+ RegisterFileInfo(64 * 1024, // size in bits
+ {"small", "large"}, // GRF modes
+ {128, 256}, // registers per thread per mode
+ 0, // number of banks
+ 0 // bank size
+ ));
+ // Initialize cache info
+ // L1 cache, XeCore level
+ this->cache_info.push_back(
+ CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1]));
+ // L3 cache, XeStack level
+ this->cache_info.push_back(
+ CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3]));
+
+ // Add the instructions
+ auto dpas = std::make_shared<DPASInstruction>();
+ instructions.emplace(dpas->getName(), dpas);
+ // instructions[dpas->name] = dpas.get();
+ owned_instructions.push_back(dpas);
+ }
+};
+} // namespace PVCuArch
+
+namespace BMGuArch {
+struct BMGuArch : public Xe2Plus {
+ // Maintains ownership of the instructions owned by BMGuArch
+ std::vector<std::shared_ptr<Instruction>> owned_instructions;
+ BMGuArch()
+ : Xe2Plus("bmg", // archName
+ "Battlemage Architecture", // archDescription
+ XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
+ {/* hierarchy */}, // Optional: empty
+ {/* register_file_info */}, // Optional: empty
+ {/* cache_info */}, // Optional: empty
+ {/* instructions */} // Optional: empty
+ ) {
+ // Initialize uArchHierarchy
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1));
+ // Initialize register file info
+ // GRF
+ this->register_file_info["GRF"] =
+ RegisterFileInfo(64 * 1024, // size in bits
+ {"small", "large"}, // GRF modes
+ {128, 256}, // registers per thread per mode
+ 0, // number of banks
+ 0 // bank size
+ );
+ // Initialize cache info
+ // L1 cache, XeCore level
+ this->cache_info.push_back(
+ CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1]));
+ // L3 cache, XeStack level
+ this->cache_info.push_back(
+ CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3]));
+
+ // Add the instructions
+ auto dpas = std::make_shared<DPASInstruction>();
+ instructions.emplace(dpas->getName(), dpas);
+ // instructions[dpas->name] = dpas.get();
+ owned_instructions.push_back(dpas);
+ }
+};
+} // namespace BMGuArch
+
+} // namespace Xe2Plus
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
new file mode 100644
index 0000000000000..9bda86df2aff9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -0,0 +1,266 @@
+//===--- uArchBase.h -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Base uArch definition for different architectures.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_BASE_H
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <shared_mutex>
+#include <string>
+#include <vector>
+
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+// Architecture HW component hierarchy used to represent thread, core, socket, ...
+struct uArchHierarchyComponent {
+ std::string name = ""; // optional name of the hierarchy component
+ // Number of lower-level hierarchy components it contains, e.g., a PVC XeCore
+ // contains 8 threads, so no_of_component=8
+ uint32_t no_of_component;
+ // Constructor
+ uArchHierarchyComponent(const std::string &name, uint32_t no_of_component)
+ : name(name), no_of_component(no_of_component) {}
+};
+
+// An enum class to represent the scope of an instruction
+enum class InstructionScopeEnum { WorkItem, Subgroup, Workgroup, Cluster };
+
+// A struct to represent basic information about an instruction
+// This struct is used to represent the information about an instruction in the
+// uArch. The information includes:
+// - the name of the instruction,
+// - the description of the instruction,
+// - the scope of the instruction.
+//
+// The name and description are strings; the scope is an InstructionScopeEnum.
+// For example, the DPAS instruction can be described as:
+// name = "dpas", description = "Dot Product Accumulate Systolic (DPAS) is a
+// matrix multiply-add operation", scope = InstructionScopeEnum::Subgroup
+
+// The primary purpose of the Instruction struct is to provide a generic way to
+// represent information about an instruction and to use this information to
+// generate the uArch. Specific instructions in a uArch can inherit from this
+// struct and add more fields as needed.
+
+struct Instruction {
+ // @TODO: Add more fields as needed
+ Instruction(std::string name, std::string desc)
+ : name(std::move(name)), description(std::move(desc)) {}
+
+ virtual ~Instruction() = default;
+ // Get methods
+ std::string getName() { return name; }
+ std::string getDescription() { return description; }
+ InstructionScopeEnum getScope() { return scope; }
+
+protected:
+ std::string name;
+ std::string description;
+ InstructionScopeEnum scope;
+};
+
+// A struct to represent register file information
+struct RegisterFileInfo {
+ // Constructor
+ RegisterFileInfo() = default;
+ RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
+ const std::vector<uint32_t> &numRegs, uint32_t num_banks,
+ uint32_t bank_size)
+ : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
+ num_banks(num_banks), bank_size(bank_size) {}
+
+ // Get methods
+ uint32_t getSize() const { return size; }
+
+ const std::vector<std::string> &getModes() const { return mode; }
+
+ const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
+ return num_regs_per_thread_per_mode;
+ }
+
+ uint32_t getNumBanks() const { return num_banks; }
+
+ uint32_t getBankSize() const { return bank_size; }
+
+protected:
+ uint32_t size; // size per register in bits
+ std::vector<std::string> mode; // e.g., "small", "large" GRF modes
+ std::vector<uint32_t>
+ num_regs_per_thread_per_mode; // number of registers per thread per mode
+ uint32_t num_banks;
+ uint32_t bank_size;
+};
+
+// A struct to represent cache information
+
+struct CacheInfo {
+ // Constructor
+ CacheInfo(uint32_t size, uint32_t line_size,
+ const uArchHierarchyComponent &component)
+ : size(size), line_size(line_size), component(component) {}
+
+ virtual ~CacheInfo() = default;
+
+ // Get methods
+ uint32_t getSize() const { return size; }
+ uint32_t getLineSize() const { return line_size; }
+ const uArchHierarchyComponent &getComponent() const { return component; }
+
+protected:
+ uint32_t size;
+ uint32_t line_size;
+ // At which component level the cache is shared
+ uArchHierarchyComponent component;
+
+ // @TODO: Add more fields as needed (e.g., associativity, num_banks,
+ // bank_size, num_ports, port_width, bank_conflicts)
+};
+
+// A struct to represent the uArch
+// This struct is used to represent the microarchitecture of a target device
+// The uArch includes:
+// - the name of the uArch,
+// - the description of the uArch,
+// - uArch hierarchy
+// - Register file information
+// - Cache information
+// - the set of instructions supported by the uArch,
+struct uArch {
+ // Constructor
+ uArch() = default;
+ uArch(const std::string &name, const std::string &description,
+ const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
+ const std::map<std::string, RegisterFileInfo> &register_file_info = {},
+ const std::vector<CacheInfo> &cache_info = {},
+ const std::map<std::string, std::shared_ptr<Instruction>>
+ &instructions = {})
+ : name(name), description(description), uArch_hierarchy(uArch_hierarchy),
+ register_file_info(register_file_info), cache_info(cache_info),
+ instructions(instructions) {}
+
+ // Get methods
+ const std::string &getName() const { return name; }
+
+ const std::string &getDescription() const { return description; }
+
+ const std::vector<uArchHierarchyComponent> &getHierarchy() const {
+ return uArch_hierarchy;
+ }
+
+ const std::map<std::string, RegisterFileInfo> &getRegisterFileInfo() const {
+ return register_file_info;
+ }
+
+ const std::vector<CacheInfo> &getCacheInfo() const { return cache_info; }
+
+ const std::map<std::string, std::shared_ptr<Instruction>> &
+ getInstructions() const {
+ return instructions;
+ }
+
+ // Get the names of the instructions supported by this architecture, i.e.,
+ // the names of the instructions that were added to the uArch.
+ std::vector<std::string> getSupportedInstructionNames() const {
+ std::vector<std::string> instructionNames;
+ for (const auto &inst : instructions) {
+ instructionNames.push_back(inst.first);
+ }
+ return instructionNames;
+ }
+
+ // Checks if an instruction is supported in this uArch
+ bool checkSupportedInstruction(const std::string &instructionName) const {
+ return instructions.find(instructionName) != instructions.end();
+ }
+
+protected:
+ std::string name; // Similar to target triple
+ std::string description;
+ std::vector<uArchHierarchyComponent> uArch_hierarchy;
+ std::map<std::string, RegisterFileInfo> register_file_info;
+ std::vector<CacheInfo> cache_info;
+ std::map<std::string, std::shared_ptr<Instruction>> instructions;
+};
+
+// A struct to represent shared memory information
+struct SharedMemory {
+ // Constructor
+ SharedMemory(uint32_t size, uint32_t alignment)
+ : size(size), alignment(alignment) {}
+
+ // Getters
+ uint32_t getSize() const { return size; }
+ uint32_t getAlignment() const { return alignment; }
+
+protected:
+ uint32_t size; // in bytes
+ uint32_t alignment; // in bytes
+ // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
+};
+
+struct uArchMap {
+public:
+ // Singleton instance
+ static uArchMap &instance() {
+ static uArchMap instance;
+ return instance;
+ }
+
+ // Insert a key-value pair; an existing key is left unchanged
+ void insert(const std::string &key, std::shared_ptr<uArch> value) {
+ std::unique_lock<std::shared_mutex> lock(mutex_);
+ // map_[key] = std::move(value); // this form would overwrite
+ map_.emplace(key, std::move(value)); // emplace does not overwrite
+ }
+
+ // Get a value by key (concurrent safe read)
+ std::shared_ptr<uArch> get(const std::string &key) const {
+ std::shared_lock<std::shared_mutex> lock(mutex_);
+ auto it = map_.find(key);
+ if (it != map_.end())
+ return it->second;
+ return nullptr;
+ }
+
+ // Check if a key exists
+ bool contains(const std::string &key) const {
+ std::shared_lock<std::shared_mutex> lock(mutex_);
+ return map_.find(key) != map_.end();
+ }
+
+ // Remove a key
+ bool erase(const std::string &key) {
+ std::unique_lock<std::shared_mutex> lock(mutex_);
+ return map_.erase(key) > 0;
+ }
+
+private:
+ uArchMap() = default;
+ uArchMap(const uArchMap &) = delete;
+ uArchMap &operator=(const uArchMap &) = delete;
+
+ mutable std::shared_mutex mutex_;
+ std::map<std::string, std::shared_ptr<uArch>> map_;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_BASE_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
new file mode 100644
index 0000000000000..27d44c38317a1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
@@ -0,0 +1,75 @@
+//===--- uArchInterfaces.h ---*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Defines the utility interfaces that are implemented by individual
+/// instructions.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD };
+struct MMAInstructionInterface {
+ // Get supported Matrix shapes
+ virtual std::vector<std::pair<uint32_t, uint32_t>>
+ getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) = 0;
+
+ // @TODO: This method takes a context object as a parameter so that the
+ // mlir::Type objects are created in that context. Since type objects are
+ // uniqued in a specific context, checks such as "aType == bType" (where
+ // aType and bType are the same type) require both types to come from the
+ // same context.
+ //
+ // One alternative is to create an enum to represent each type, but this
+ // puts an extra burden on the user to convert these enums to specific types.
+ // In fact, a utility that converts an enum to a type (enumToType()) and vice
+ // versa would still have to use the context object.
+ //
+ // Until we have a better solution, we stick to passing the context object to
+ // this method.
+ virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
+ MMAOpndEnum matrixType) = 0;
+ virtual bool
+ checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape,
+ mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) = 0;
+ virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) = 0;
+ virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ ...
[truncated]
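A side note on `uArchMap::insert` in the patch above: it calls `map_.emplace`, which leaves an existing entry untouched, while the commented-out `map_[key] = ...` form would replace it. The following standalone sketch (plain `std::map<std::string, int>`, hypothetical function names, no MLIR types) shows the difference between `emplace` and `insert_or_assign`.

```cpp
#include <cassert>
#include <map>
#include <string>

// emplace: the second call is a no-op because the key already exists.
inline int emplaceThenRead() {
  std::map<std::string, int> m;
  m.emplace("pvc", 1);
  m.emplace("pvc", 2); // ignored, "pvc" is already present
  return m.at("pvc");
}

// insert_or_assign (C++17): the second call replaces the stored value.
inline int assignThenRead() {
  std::map<std::string, int> m;
  m.insert_or_assign("pvc", 1);
  m.insert_or_assign("pvc", 2); // overwrites the previous value
  return m.at("pvc");
}
```

Which behavior is intended for the uArch map is not stated in the patch; the sketch only illustrates the two options.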
The uArch infrastructure provides:
- A set of data structures to represent a uArch and its necessary components (e.g., instructions, register files, caches).
- A set of utility interfaces that are common to a family of ops (e.g., mma ops, 2DBlockIO ops). The implementations of these interfaces are provided by the specific instructions. Each family of ops provides these 5 common APIs, although some families may have additional utility APIs. The 5 common APIs are:
  - getSupportedShapes
  - getSupportedTypes
  - checkSupportedShapesAndTypes
  - checkSupportedTypes
  - validate

Add support for PVC and BMG architectures.
Add support for DPAS instruction.
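As a rough illustration of how the five common APIs in this description fit together, the sketch below mocks up the interface without any MLIR dependency. `ElemType`, `Operand`, `Shape`, `MmaInterfaceSketch`, `contains` and `ToyMma` are hypothetical names; the shapes and types accepted by `ToyMma` are made up for the example and are not the DPAS values from this patch. What `validate` checks beyond `checkSupportedShapesAndTypes` is also an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-ins; the PR's interface uses mlir::Type and MLIRContext.
enum class ElemType { F16, F32 };
enum class Operand { A, B, C, D };
using Shape = std::pair<uint32_t, uint32_t>;

// Simplified mirror of the five common APIs.
struct MmaInterfaceSketch {
  virtual ~MmaInterfaceSketch() = default;
  virtual std::vector<Shape> getSupportedShapes(ElemType type,
                                                Operand opnd) const = 0;
  virtual std::vector<ElemType> getSupportedTypes(Operand opnd) const = 0;
  virtual bool checkSupportedTypes(ElemType a, ElemType b, ElemType c,
                                   ElemType d) const = 0;
  virtual bool checkSupportedShapesAndTypes(Shape a, Shape b, Shape c, Shape d,
                                            ElemType aT, ElemType bT,
                                            ElemType cT, ElemType dT) const = 0;
  virtual bool validate(Shape a, Shape b, Shape c, Shape d, ElemType aT,
                        ElemType bT, ElemType cT, ElemType dT) const = 0;
};

template <typename T>
bool contains(const std::vector<T> &values, const T &value) {
  return std::find(values.begin(), values.end(), value) != values.end();
}

// Toy instruction with made-up limits: A is 8x16 f16, B is 16x16 f16,
// C and D are 8x16 in f16 or f32.
struct ToyMma : MmaInterfaceSketch {
  std::vector<ElemType> getSupportedTypes(Operand opnd) const override {
    if (opnd == Operand::A || opnd == Operand::B)
      return {ElemType::F16};
    return {ElemType::F16, ElemType::F32};
  }

  std::vector<Shape> getSupportedShapes(ElemType type,
                                        Operand opnd) const override {
    if (!contains(getSupportedTypes(opnd), type))
      return {};
    if (opnd == Operand::B)
      return {Shape{16, 16}};
    return {Shape{8, 16}};
  }

  bool checkSupportedTypes(ElemType a, ElemType b, ElemType c,
                           ElemType d) const override {
    return contains(getSupportedTypes(Operand::A), a) &&
           contains(getSupportedTypes(Operand::B), b) &&
           contains(getSupportedTypes(Operand::C), c) &&
           contains(getSupportedTypes(Operand::D), d);
  }

  bool checkSupportedShapesAndTypes(Shape a, Shape b, Shape c, Shape d,
                                    ElemType aT, ElemType bT, ElemType cT,
                                    ElemType dT) const override {
    return checkSupportedTypes(aT, bT, cT, dT) &&
           contains(getSupportedShapes(aT, Operand::A), a) &&
           contains(getSupportedShapes(bT, Operand::B), b) &&
           contains(getSupportedShapes(cT, Operand::C), c) &&
           contains(getSupportedShapes(dT, Operand::D), d);
  }

  bool validate(Shape a, Shape b, Shape c, Shape d, ElemType aT, ElemType bT,
                ElemType cT, ElemType dT) const override {
    // M, K and N must agree across operands, and D must match C.
    bool dimsAgree = a.first == c.first && a.second == b.first &&
                     b.second == c.second && c == d;
    return dimsAgree &&
           checkSupportedShapesAndTypes(a, b, c, d, aT, bT, cT, dT);
  }
};
```

A caller such as a pass matcher could then ask `ToyMma{}.validate(...)` for a candidate tile and fall back to another lowering when it returns false.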
ed85437 to b1d37c0
High-level comment: could the "family of ops" be represented by |
Hi Rolf, Thank you for bringing it up. I don't have any opposition to this in principle. However, I have a few concerns/discussion points,
High-level question: is there any info that gets pulled at runtime?
I imagine most of these hardware properties don't really change. Can't all this be static compile time info?
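One way to read this question is that the per-device numbers could live in a `constexpr` table. The sketch below is an assumption about what that might look like, not code from the PR. `XeCoreInfoStatic`, `kCores` and `lookupCore` are hypothetical names; the values are copied from the `XeCoreInfo(...)` calls for PVC and BMG in the diff.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <string_view>

// Hypothetical compile-time counterpart of XeCoreInfo plus SharedMemory.
struct XeCoreInfoStatic {
  std::string_view arch;
  uint32_t numThreads;
  uint32_t sharedMemoryBytes;
  uint32_t sharedMemoryAlignment;
  uint32_t numVectorUnits;
  uint32_t numMatrixUnits;
};

// Values taken from the PVCuArch and BMGuArch constructors in the patch.
inline constexpr std::array<XeCoreInfoStatic, 2> kCores{{
    {"pvc", 8, 512u * 1024u, 4, 8, 8},
    {"bmg", 8, 256u * 1024u, 4, 8, 8},
}};

// Returns the table entry for `arch`, or nullptr if the name is unknown.
constexpr const XeCoreInfoStatic *lookupCore(std::string_view arch) {
  for (const XeCoreInfoStatic &info : kCores)
    if (info.arch == arch)
      return &info;
  return nullptr;
}

// The lookup itself can be evaluated at compile time.
static_assert(lookupCore("pvc") != nullptr &&
              lookupCore("pvc")->sharedMemoryBytes == 512u * 1024u);
```

The trade-off is that a table like this cannot hold polymorphic instruction objects such as `DPASInstruction`, which the patch stores behind `std::shared_ptr<Instruction>`.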
//===----------------------------------------------------------------------===//
//
/// \file
/// Xe2 uArch definition.
//
within the whole header
This section could be either removed or expanded to add more context, links to relevant docs etc.
Thanks, addressed.
@@ -0,0 +1,182 @@
//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
Suggested change:
- //===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
+ //===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
Done.
///
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
Suggested change:
- #ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+ #ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
Done. Thanks!
///
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
Suggested change:
- #ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
+ #ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
Full file name in the guard
Done.
@@ -0,0 +1,197 @@
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/YAMLTraits.h"
I don't think this header is relevant
Removed.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/YAMLTraits.h"
#include <algorithm>
#include <iostream>
Is this needed as well?
Removed.
if (AType.isF16() || BType.isF16()) {
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
(!DType.isF32() && !DType.isF16())) {
llvm::errs()
I can see this helper being used as part of a pass matcher. I definitely don't want to get spammed with errors 😉
Overall, these messages are really verbose and I'm not sure they are that useful.
Maybe a table of all supported combinations could be part of the function docs (source or header)?
A shorter error could be hidden under debug: LDBG() << "msg"
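A minimal sketch of the quieter approach suggested here, assuming a plain boolean check whose diagnostic text is only emitted when a debug flag is set. It does not use LLVM's `LDBG()` or `mlir::Type`; `Elem`, `gDebugEnabled`, `isF16OrF32` and `checkF16Types` are hypothetical stand-ins, and the condition mirrors only the f16 branch quoted in the snippet under review.

```cpp
#include <cassert>
#include <iostream>
#include <optional>

// Hypothetical element types standing in for mlir::Type.
enum class Elem { F16, F32, BF16 };

// Stand-in for whatever switch gates debug output in the real code base.
inline bool gDebugEnabled = false;

inline bool isF16OrF32(Elem t) { return t == Elem::F16 || t == Elem::F32; }

// Mirrors the quoted condition: if A or B is f16, then A and B must match,
// C (when present) must be f16 or f32, and D must be f16 or f32.
inline bool checkF16Types(Elem a, Elem b, std::optional<Elem> c, Elem d) {
  if (a != Elem::F16 && b != Elem::F16)
    return true; // other branches are out of scope for this sketch
  bool ok = a == b && (!c || isF16OrF32(*c)) && isF16OrF32(d);
  // Stay silent by default so a matcher can probe many candidates.
  if (!ok && gDebugEnabled)
    std::cerr << "dpas: unsupported f16 type combination\n";
  return ok;
}
```

With this shape, a pass can call the check repeatedly while matching and only see the short message when debugging is switched on.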
virtual std::vector<uint32_t> getSupportedN(mlir::Type type) override;
};

namespace PVCuArch {
nit: I think we can skip this extra nested namespace
I was actually keeping this in case we decide to make specific versions of uArchs (e.g., different versions of BMG).
};

namespace PVCuArch {
struct PVCuArch : public Xe2Plus {
Can't all this whole class be static?
Since the setup of the uArch map uses an instance of uArch, I kept it as non-static.
@@ -35,6 +36,14 @@ void XeGPUDialect::initialize() {
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
>();

// Populate the uArchMap with the supported target devices
I don't think it belongs in the dialect initializer.
Besides that, let's start with a simple design and let users create uArch instances on demand.
I can remove this from the dialect initializer.
I kept it here because this way the uArch map is populated once the dialect is loaded, and any pass or the dialect can use it from there on. But the other option can also work, where a pass initializes the uArch map and uses it afterwards.
Any specific reason why it can't be in dialect initializer?
Nothing in particular; it is all static info. I actually tried to follow the MLIR way: since it is like a utility, I tried to follow the utility functions in different dialects. But I am open to change.
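To make the "create uArch instances on demand" option from this thread concrete, here is a small MLIR-free sketch of a registry that builds an entry on first lookup instead of populating everything in the dialect initializer. `ArchInfo`, `ArchRegistry` and `getOrCreate` are hypothetical names, and the locking only loosely follows the `uArchMap` from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>

// Hypothetical stand-in for a uArch description.
struct ArchInfo {
  std::string name;
};

class ArchRegistry {
public:
  using Factory = std::function<std::shared_ptr<ArchInfo>()>;

  // Returns the cached instance for `key`, building it with `factory` on
  // first use. The factory runs under the lock, so it should be cheap.
  std::shared_ptr<ArchInfo> getOrCreate(const std::string &key,
                                        const Factory &factory) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = map_.find(key);
    if (it != map_.end())
      return it->second;
    std::shared_ptr<ArchInfo> value = factory();
    map_.emplace(key, value);
    return value;
  }

  // Number of architectures created so far.
  std::size_t size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return map_.size();
  }

private:
  mutable std::mutex mutex_;
  std::map<std::string, std::shared_ptr<ArchInfo>> map_;
};
```

A pass would call `getOrCreate("pvc", ...)` when it first needs the description, so nothing is constructed for targets that are never queried.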